diff --git a/.ci/docker/ci_commit_pins/optimum-executorch.txt b/.ci/docker/ci_commit_pins/optimum-executorch.txt index e0af9344588..df87f35a69d 100644 --- a/.ci/docker/ci_commit_pins/optimum-executorch.txt +++ b/.ci/docker/ci_commit_pins/optimum-executorch.txt @@ -1 +1 @@ -4361747abfc55e40e929396ed986efe775d745f9 +d03e90c2cd9048e6d9a75285c0355f033cd016fc diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index a8de771a69d..61b7ab7807e 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -e6f766c7d750d40603eee3f66c5915bac606b3ea +b31bad1b8f1331bf43d47f46602cf6141db56844 diff --git a/.ci/docker/common/install_arm.sh b/.ci/docker/common/install_arm.sh new file mode 100644 index 00000000000..dec8a1693ee --- /dev/null +++ b/.ci/docker/common/install_arm.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +install_arm_prerequiresites() { + apt-get update -y + apt-get install -y --no-install-recommends \ + mesa-vulkan-drivers libvulkan1 + rm -rf /var/lib/apt/lists/* +} + +install_arm_prerequiresites diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index d16b91cc7a3..3ce3bda50fa 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -30,7 +30,6 @@ sphinx-reredirects==0.1.4 matplotlib>=3.9.4 sphinx-copybutton==0.5.2 # PyTorch Theme --e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2 - +pytorch_sphinx_theme2==0.2.0 # script unit test requirements yaspin==3.1.0 diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index fddd7e6df36..9c57e5ee951 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -83,6 +83,9 @@ RUN if [ -n "${ANDROID_NDK_VERSION}" ]; then bash ./install_android.sh; fi RUN rm install_android.sh ARG ARM_SDK +COPY ./common/install_arm.sh install_arm.sh +RUN if [ -n "${ARM_SDK}" ]; then bash ./install_arm.sh; fi +RUN rm install_arm.sh ARG ZEPHYR_SDK COPY ./common/install_zephyr.sh install_zephyr.sh diff --git a/.ci/scripts/cuda_benchmark.py b/.ci/scripts/cuda_benchmark.py new file mode 100644 index 00000000000..b135925d4b4 --- /dev/null +++ b/.ci/scripts/cuda_benchmark.py @@ -0,0 +1,939 @@ +""" +Benchmark script for CUDA model runners. +Runs model runner commands multiple times and collects performance metrics. +Supports whisper, voxtral, gemma3, and other CUDA models. +""" + +import argparse +import json +import statistics +import subprocess +import sys +from dataclasses import dataclass +from typing import List, Optional, Tuple + + +@dataclass +class RunMetrics: + """Metrics from a single run.""" + + generated_tokens: int + tokens_per_sec: float + model_load_time_ms: float + total_inference_time_ms: float + encoder_time_ms: float + generation_time_ms: float + first_token_latency_ms: float + + def __repr__(self): + return ( + f"Tokens: {self.generated_tokens}, " + f"Throughput: {self.tokens_per_sec:.2f} t/s, " + f"Model load: {self.model_load_time_ms:.0f}ms, " + f"Total inference: {self.total_inference_time_ms:.0f}ms, " + f"Encoder: {self.encoder_time_ms:.0f}ms, " + f"Generation: {self.generation_time_ms:.0f}ms, " + f"First token: {self.first_token_latency_ms:.0f}ms" + ) + + +def parse_pytorch_observer_log(log_line: str) -> Optional[RunMetrics]: + """Parse PyTorchObserver JSON output and compute metrics.""" + try: + # Find the JSON part in the log line + if "PyTorchObserver" not in log_line: + return None + + json_str = log_line.split("PyTorchObserver")[1].strip() + data = json.loads(json_str) + + # Extract values + generated_tokens = data.get("generated_tokens", 0) + inference_start_ms = data.get("inference_start_ms", 0) + inference_end_ms = data.get("inference_end_ms", 0) + prompt_eval_end_ms = data.get("prompt_eval_end_ms", 0) + first_token_ms = data.get("first_token_ms", 0) + model_load_start_ms = data.get("model_load_start_ms", 0) + model_load_end_ms = data.get("model_load_end_ms", 0) + + # Compute metrics + # Total inference time: from inference start to inference end + total_inference_time_ms = inference_end_ms - inference_start_ms + + # Encoder time: from inference start to prompt evaluation end + encoder_time_ms = prompt_eval_end_ms - inference_start_ms + + # Generation time: from prompt evaluation end to inference end + generation_time_ms = inference_end_ms - prompt_eval_end_ms + + # Calculate throughput based on generation time + tokens_per_sec = ( + (generated_tokens / generation_time_ms * 1000) + if generation_time_ms > 0 + else 0 + ) + model_load_time_ms = model_load_end_ms - model_load_start_ms + first_token_latency_ms = first_token_ms - prompt_eval_end_ms + + return RunMetrics( + generated_tokens=generated_tokens, + tokens_per_sec=tokens_per_sec, + model_load_time_ms=model_load_time_ms, + total_inference_time_ms=total_inference_time_ms, + encoder_time_ms=encoder_time_ms, + generation_time_ms=generation_time_ms, + first_token_latency_ms=first_token_latency_ms, + ) + except (json.JSONDecodeError, KeyError, ValueError) as e: + print(f"Error parsing PyTorchObserver log: {e}", file=sys.stderr) + return None + + +def get_gpu_clocks() -> Optional[Tuple[str, str]]: + """Get current GPU and memory clock frequencies.""" + try: + # Get GPU clock + result_gpu = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=clocks.gr", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + # Get memory clock + result_mem = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=clocks.mem", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + + if result_gpu.returncode == 0 and result_mem.returncode == 0: + gpu_clock = result_gpu.stdout.strip().split("\n")[0] + mem_clock = result_mem.stdout.strip().split("\n")[0] + return gpu_clock, mem_clock + except Exception as e: + print(f"Warning: Failed to get GPU clocks: {e}", file=sys.stderr) + return None + + +def set_gpu_clocks(gpu_clock: Optional[int] = None) -> bool: + """ + Set GPU clock frequency to a fixed value. + + Args: + gpu_clock: Target GPU clock frequency in MHz. + If None, will use max available. + + Returns: + True if successful, False otherwise + """ + try: + print("\n[GPU Clock Setup] Fixing GPU clock frequency...") + + # Enable persistence mode + result = subprocess.run( + ["sudo", "nvidia-smi", "-pm", "1"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + f"Warning: Failed to enable persistence mode: {result.stderr}", + file=sys.stderr, + ) + return False + print("✓ Enabled persistence mode") + + # Lock GPU clocks + if gpu_clock is None: + # Get max GPU clock + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=clocks.max.gr", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + gpu_clock = int(result.stdout.strip().split("\n")[0]) + print(f"✓ Detected max GPU clock: {gpu_clock} MHz") + + # Lock GPU clock to the target frequency + result = subprocess.run( + ["sudo", "nvidia-smi", "-lgc", f"{gpu_clock},{gpu_clock}"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + f"Warning: Failed to lock GPU clock: {result.stderr}", + file=sys.stderr, + ) + return False + + print(f"✓ Locked GPU clock to {gpu_clock} MHz") + return True + + except Exception as e: + print(f"Error: Failed to set GPU clocks: {e}", file=sys.stderr) + return False + + +def reset_gpu_clocks() -> bool: + """Reset GPU clock frequencies to default.""" + try: + print("\n[GPU Clock Cleanup] Resetting GPU clock frequency...") + + # Reset GPU clocks + result = subprocess.run( + ["sudo", "nvidia-smi", "-rgc"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + f"Warning: Failed to reset GPU clock: {result.stderr}", + file=sys.stderr, + ) + return False + print("✓ Reset GPU clock to default") + + # Disable persistence mode + result = subprocess.run( + ["sudo", "nvidia-smi", "-pm", "0"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode != 0: + print( + "Warning: Failed to disable persistence mode: " f"{result.stderr}", + file=sys.stderr, + ) + return False + print("✓ Disabled persistence mode") + + return True + + except Exception as e: + print(f"Error: Failed to reset GPU clocks: {e}", file=sys.stderr) + return False + + +def _print_warmup_info(warmup_runs: int) -> None: + """Print warmup phase information.""" + if warmup_runs > 0: + print(f"\n{'='*70}") + print(f"WARMUP PHASE: Running {warmup_runs} warmup iterations...") + print(f"{'='*70}") + + +def _print_benchmark_info( + actual_benchmark_runs: int, trim_count: int, num_runs: int +) -> None: + """Print benchmark phase information.""" + print(f"\n{'='*70}") + print(f"BENCHMARK PHASE: Running {actual_benchmark_runs} iterations") + print(f"Will trim top and bottom {trim_count} results (10% of {num_runs})") + print(f"Final statistics will be based on middle {num_runs} results") + print(f"{'='*70}") + + +def _run_single_iteration( + command: str, run_num: int, verbose: bool +) -> Optional[RunMetrics]: + """ + Run a single benchmark iteration and return metrics. + + Args: + command: Command to execute + run_num: Current run number + verbose: Print verbose output + + Returns: + RunMetrics if successful, None otherwise + """ + try: + # Run command and capture output + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + + if result.returncode != 0: + print( + f"Error: Command failed with return code {result.returncode}", + file=sys.stderr, + ) + if result.stderr: + print(f"stderr: {result.stderr}", file=sys.stderr) + return None + + # Search for PyTorchObserver line in output + observer_line = None + for line in result.stdout.split("\n"): + if "PyTorchObserver" in line: + observer_line = line + break + + if observer_line is None: + print( + f"Warning: No PyTorchObserver output found in run {run_num}", + file=sys.stderr, + ) + if verbose: + print(f"stdout:\n{result.stdout}", file=sys.stderr) + return None + + # Parse and return metrics + metrics = parse_pytorch_observer_log(observer_line) + if metrics is None: + print( + f"Warning: Failed to parse metrics from run {run_num}", + file=sys.stderr, + ) + return None + + print(f"✓ {metrics}") + return metrics + + except subprocess.TimeoutExpired: + print(f"Error: Command timed out on run {run_num}", file=sys.stderr) + return None + except Exception as e: + print(f"Error on run {run_num}: {e}", file=sys.stderr) + return None + + +def run_model_benchmark( + command: str, + num_runs: int = 5, + warmup_runs: int = 0, + verbose: bool = False, +) -> List[RunMetrics]: + """ + Run the model runner command multiple times and collect metrics. + + For trimmed mean calculation, this function runs extra iterations + to ensure we can trim outliers. Based on num_runs, we calculate + trim_count = num_runs * 0.1, then run num_runs + 2*trim_count total + iterations. The top and bottom trim_count results will be discarded. + + Args: + command: Full command to run + num_runs: Number of benchmark runs requested by user (after trim) + warmup_runs: Number of warmup runs (results will be discarded) + verbose: Print detailed output + + Returns: + List of RunMetrics from benchmark runs (excluding warmup). + """ + # Calculate trim count and total runs + trim_count = int(num_runs * 0.1) + actual_benchmark_runs = num_runs + 2 * trim_count + total_runs = warmup_runs + actual_benchmark_runs + + # Print phase information + _print_warmup_info(warmup_runs) + _print_benchmark_info(actual_benchmark_runs, trim_count, num_runs) + + # Execute all runs + results = [] + for run_num in range(1, total_runs + 1): + is_warmup = run_num <= warmup_runs + phase = "Warmup" if is_warmup else "Benchmark" + benchmark_run_num = run_num - warmup_runs if not is_warmup else run_num + + # Print run header + if is_warmup: + print(f"\n[{phase} {run_num}/{warmup_runs}] Executing: {command}") + else: + print( + f"\n[{phase} {benchmark_run_num}/{actual_benchmark_runs}] " + f"Executing: {command}" + ) + + # Run iteration and collect metrics + metrics = _run_single_iteration(command, run_num, verbose) + if metrics is not None and not is_warmup: + results.append(metrics) + + return results + + +def calculate_trimmed_stats( + values: List[float], trim_count: int +) -> Tuple[List[float], float, float, float, float]: + """ + Calculate statistics on trimmed data. + + Args: + values: List of numeric values + trim_count: Number of values to trim from each end + + Returns: + Tuple of (trimmed_values, min, max, mean, stdev) + """ + if not values: + return [], 0.0, 0.0, 0.0, 0.0 + + # Sort values + sorted_values = sorted(values) + n = len(sorted_values) + + # Trim if we have enough data and trim_count > 0 + if trim_count > 0 and n > 2 * trim_count: + trimmed_values = sorted_values[trim_count : n - trim_count] + else: + trimmed_values = sorted_values + + # Calculate stats on trimmed data + min_val = min(trimmed_values) + max_val = max(trimmed_values) + mean_val = statistics.mean(trimmed_values) + stdev_val = statistics.stdev(trimmed_values) if len(trimmed_values) > 1 else 0.0 + + return trimmed_values, min_val, max_val, mean_val, stdev_val + + +@dataclass +class MetricStats: + """Statistics for a single metric with operations.""" + + name: str + mean: float + min_val: float + max_val: float + stdev: float + unit: str = "" + extra_info: dict | None = None + + def create_v3_record( + self, + model_name: str, + backend: str, + runner_name: str, + runner_type: str, + base_extra_info: dict, + ) -> dict: + """ + Create a v3 format record for this metric. + + Args: + model_name: Model name with quantization + backend: Backend name (e.g., "cuda-aoti") + runner_name: GPU device name + runner_type: CUDA driver version + base_extra_info: Base extra_info dict to copy + + Returns: + Complete v3 format metric record + """ + extra_stats = { + "min": self.min_val, + "max": self.max_val, + "stdev": self.stdev, + } + if self.extra_info: + extra_stats.update(self.extra_info) + + return { + "benchmark": { + "name": "ExecuTorch", + "mode": "inference", + "extra_info": base_extra_info.copy(), + }, + "model": { + "name": model_name, + "type": "OSS model", + "backend": backend, + }, + "metric": { + "name": self.name, + "benchmark_values": [self.mean], + "target_value": 0, + "extra_info": extra_stats, + }, + "runners": [{"name": runner_name, "type": runner_type}], + } + + def print_stats(self) -> None: + """Print formatted statistics for this metric.""" + # Determine precision based on metric type + is_throughput = "tokens" in self.name.lower() + precision = 2 if is_throughput else 0 + + # Format metric name for display + display_name = self.name.replace("_", " ").upper() + if self.unit: + display_name = f"{display_name} ({self.unit})" + + print(f"{display_name}:") + print(f" Min: {self.min_val:.{precision}f} {self.unit}") + print(f" Max: {self.max_val:.{precision}f} {self.unit}") + print(f" Mean: {self.mean:.{precision}f} {self.unit}") + print(f" Stdev: {self.stdev:.{precision}f} {self.unit}") + print() + + +@dataclass +class BenchmarkResults: + """Summary of benchmark results.""" + + model_name: str + total_runs: int + trimmed_runs: int + discarded_runs: int + generated_tokens: int + + # Metrics + throughput: MetricStats + model_load_time: MetricStats + total_inference_time: MetricStats + encoder_time: MetricStats + generation_time: MetricStats + first_token_latency: MetricStats + + def save_json(self, output_path: str) -> None: + """Save results to JSON file.""" + with open(output_path, "w") as f: + json.dump(self.to_dict(), f, indent=2) + print(f"\n✓ Results saved to: {output_path}") + + def to_dict(self) -> dict: + """Convert results to dictionary for JSON serialization.""" + return { + "model_name": self.model_name, + "total_runs": self.total_runs, + "trimmed_runs": self.trimmed_runs, + "discarded_runs": self.discarded_runs, + "generated_tokens": self.generated_tokens, + "throughput_mean": self.throughput.mean, + "throughput_min": self.throughput.min_val, + "throughput_max": self.throughput.max_val, + "throughput_stdev": self.throughput.stdev, + "model_load_time_mean": self.model_load_time.mean, + "model_load_time_min": self.model_load_time.min_val, + "model_load_time_max": self.model_load_time.max_val, + "model_load_time_stdev": self.model_load_time.stdev, + "total_inference_time_mean": self.total_inference_time.mean, + "total_inference_time_min": self.total_inference_time.min_val, + "total_inference_time_max": self.total_inference_time.max_val, + "total_inference_time_stdev": self.total_inference_time.stdev, + "encoder_time_mean": self.encoder_time.mean, + "encoder_time_min": self.encoder_time.min_val, + "encoder_time_max": self.encoder_time.max_val, + "encoder_time_stdev": self.encoder_time.stdev, + "generation_time_mean": self.generation_time.mean, + "generation_time_min": self.generation_time.min_val, + "generation_time_max": self.generation_time.max_val, + "generation_time_stdev": self.generation_time.stdev, + "first_token_latency_mean": self.first_token_latency.mean, + "first_token_latency_min": self.first_token_latency.min_val, + "first_token_latency_max": self.first_token_latency.max_val, + "first_token_latency_stdev": self.first_token_latency.stdev, + } + + def to_v3_format( + self, + model: str, + quantization: str, + git_sha: str, + workflow_run_id: str, + workflow_run_url: str = "", + gpu_name: str = "CUDA", + cuda_driver_version: str = "cuda", + ) -> List[dict]: + """ + Transform benchmark results to PyTorch benchmark database v3 format. + + Args: + model: Model name (e.g., "openai/whisper-small") + quantization: Quantization type (e.g., "non-quantized") + git_sha: Git commit SHA + workflow_run_id: GitHub workflow run ID + workflow_run_url: GitHub workflow run URL + gpu_name: GPU device name (e.g., "Tesla V100", "A100") + cuda_driver_version: CUDA driver version (e.g., "12.6", "535.104.05") + + Returns: + List of benchmark records in v3 format + """ + # Shared configuration + model_name_with_quant = f"{model}_{quantization}" + backend = "cuda-aoti" + runner_name = gpu_name + runner_type = cuda_driver_version + + # Create base extra_info + base_extra_info = { + "backend": "cuda", + "quantization": quantization, + "git_sha": git_sha, + "workflow_run_id": workflow_run_id, + } + if workflow_run_url: + base_extra_info["workflow_run_url"] = workflow_run_url + + # Create v3 records for all metrics + return [ + self.throughput.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.model_load_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.total_inference_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.encoder_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.generation_time.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + self.first_token_latency.create_v3_record( + model_name_with_quant, + backend, + runner_name, + runner_type, + base_extra_info, + ), + ] + + +def compute_summary( + model_name: str, results: List[RunMetrics], requested_runs: int +) -> BenchmarkResults: + """ + Compute summary statistics using trimmed data. + + All statistics (min, max, mean, stdev) are calculated based on + the trimmed dataset after removing outliers. + + Args: + model_name: Name of the model being benchmarked + results: List of all collected run metrics + requested_runs: Number of runs originally requested by user + + Returns: + BenchmarkResults object with all computed statistics + """ + if not results: + raise ValueError("No valid results to summarize.") + + # Calculate trim count based on requested runs (not actual runs) + trim_count = int(requested_runs * 0.1) + + # Helper to create MetricStats from values + def create_metric_stats( + name: str, values: List[float], unit: str = "", extra_info: dict | None = None + ) -> MetricStats: + _, min_val, max_val, mean_val, stdev_val = calculate_trimmed_stats( + values, trim_count + ) + return MetricStats( + name=name, + mean=mean_val, + min_val=min_val, + max_val=max_val, + stdev=stdev_val, + unit=unit, + extra_info=extra_info, + ) + + # Get the first trimmed result to get trimmed_runs count + trimmed_throughput, _, _, _, _ = calculate_trimmed_stats( + [r.tokens_per_sec for r in results], trim_count + ) + + return BenchmarkResults( + model_name=model_name, + total_runs=len(results), + trimmed_runs=len(trimmed_throughput), + discarded_runs=trim_count * 2, + generated_tokens=results[0].generated_tokens, + throughput=create_metric_stats( + "throughput(tokens/sec)", + [r.tokens_per_sec for r in results], + "t/s", + {"trimmed_runs": len(trimmed_throughput)}, + ), + model_load_time=create_metric_stats( + "model_load_time(ms)", + [r.model_load_time_ms for r in results], + "ms", + ), + total_inference_time=create_metric_stats( + "total_inference_time(ms)", + [r.total_inference_time_ms for r in results], + "ms", + ), + encoder_time=create_metric_stats( + "encoder_time(ms)", + [r.encoder_time_ms for r in results], + "ms", + ), + generation_time=create_metric_stats( + "generation_time(ms)", + [r.generation_time_ms for r in results], + "ms", + ), + first_token_latency=create_metric_stats( + "first_token_latency(ms)", + [r.first_token_latency_ms for r in results], + "ms", + ), + ) + + +def print_summary(summary: BenchmarkResults) -> None: + """Print formatted summary of benchmark results.""" + print("\n" + "=" * 70) + print(f"BENCHMARK SUMMARY for model: {summary.model_name}") + print("=" * 70) + print(f"Total runs collected: {summary.total_runs}") + print(f"Trimmed to: {summary.trimmed_runs} runs") + print( + f"(Discarded {summary.discarded_runs // 2} highest and " + f"{summary.discarded_runs // 2} lowest results)" + ) + print(f"Generated tokens per run: {summary.generated_tokens}") + print() + + # Print all metrics using their print_stats method + summary.throughput.print_stats() + summary.model_load_time.print_stats() + summary.total_inference_time.print_stats() + summary.encoder_time.print_stats() + summary.generation_time.print_stats() + summary.first_token_latency.print_stats() + + print("=" * 70) + + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser( + description="Benchmark CUDA model runners and collect performance metrics" + ) + parser.add_argument( + "--runner_command", + type=str, + required=True, + help="Full command to run the model runner", + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Name of the model being benchmarked", + ) + parser.add_argument( + "--num_runs", + type=int, + default=50, + help="Number of benchmark runs (default: 50)", + ) + parser.add_argument( + "--warmup_runs", + type=int, + default=0, + help="Number of warmup runs before benchmark (default: 0.1 * num_runs)", + ) + parser.add_argument( + "--fix_gpu_clock", + type=bool, + default=True, + help="Fix GPU clock frequency to maximum before benchmarking", + ) + parser.add_argument( + "--gpu_clock", + type=int, + default=None, + help="Target GPU clock frequency in MHz (requires " + "--fix_gpu_clock). If not specified, uses max available.", + ) + parser.add_argument( + "--output_json", + type=str, + default=None, + help="Path to save JSON results", + ) + parser.add_argument( + "--output_v3", + type=str, + default=None, + help="Path to save v3 format JSON results for dashboard", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model ID (e.g., 'openai/whisper-small') - required for v3 format", + ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help="Quantization type (e.g., 'non-quantized') - required for v3 format", + ) + parser.add_argument( + "--git_sha", + type=str, + default=None, + help="Git commit SHA - required for v3 format", + ) + parser.add_argument( + "--workflow_run_id", + type=str, + default=None, + help="GitHub workflow run ID - required for v3 format", + ) + parser.add_argument( + "--workflow_run_url", + type=str, + default="", + help="GitHub workflow run URL - optional for v3 format", + ) + parser.add_argument( + "--gpu_name", + type=str, + default=None, + help="GPU device name (e.g., 'Tesla V100', 'A100') - optional for v3 format", + ) + parser.add_argument( + "--cuda_driver_version", + type=str, + default=None, + help="CUDA driver version (e.g., '12.6', '535.104.05') - optional for v3 format", + ) + parser.add_argument("--verbose", action="store_true", help="Print verbose output") + + args = parser.parse_args() + + warmup_runs = ( + int(0.1 * args.num_runs) if args.warmup_runs == 0 else args.warmup_runs + ) + + print(f"Running benchmark for model: {args.model_name}") + print(f"Number of runs: {args.num_runs}") + if warmup_runs > 0: + print(f"Warmup runs: {warmup_runs}") + if args.fix_gpu_clock: + clock_str = f"{args.gpu_clock}" if args.gpu_clock else "max available" + print(f"GPU clock will be fixed to: {clock_str} MHz") + print(f"Command: {args.runner_command}\n") + + # Fix GPU clocks if requested + gpu_clock_fixed = False + if args.fix_gpu_clock: + # Get current clocks before fixing + initial_clocks = get_gpu_clocks() + if initial_clocks: + print( + f"Current GPU clocks - GPU: {initial_clocks[0]} MHz, " + f"Memory: {initial_clocks[1]} MHz" + ) + + gpu_clock_fixed = set_gpu_clocks(args.gpu_clock) + if not gpu_clock_fixed: + print( + "Warning: Failed to fix GPU clocks. " + "Continuing without fixed clocks...", + file=sys.stderr, + ) + + try: + # Run benchmark + results = run_model_benchmark( + command=args.runner_command, + num_runs=args.num_runs, + warmup_runs=warmup_runs, + verbose=args.verbose, + ) + + # Compute and print summary + summary = compute_summary(args.model_name, results, args.num_runs) + print_summary(summary) + + # Save JSON results if requested + if args.output_json: + summary.save_json(args.output_json) + + # Save v3 format if requested + if args.output_v3: + # Validate required parameters for v3 format + if not all( + [args.model, args.quantization, args.git_sha, args.workflow_run_id] + ): + print( + "Error: --output_v3 requires --model, --quantization, " + "--git_sha, and --workflow_run_id", + file=sys.stderr, + ) + sys.exit(1) + + v3_records = summary.to_v3_format( + model=args.model, + quantization=args.quantization, + git_sha=args.git_sha, + workflow_run_id=args.workflow_run_id, + workflow_run_url=args.workflow_run_url, + gpu_name=args.gpu_name if args.gpu_name else "UNKNOWN GPU", + cuda_driver_version=( + args.cuda_driver_version if args.cuda_driver_version else "cuda" + ), + ) + + with open(args.output_v3, "w") as f: + json.dump(v3_records, f, indent=2) + + print(f"✓ v3 format results saved to: {args.output_v3}") + print(f"✓ Generated {len(v3_records)} v3 records for dashboard upload") + + finally: + # Reset GPU clocks if they were fixed + if gpu_clock_fixed: + reset_gpu_clocks() + + +if __name__ == "__main__": + main() diff --git a/.ci/scripts/export_model_cuda_artifact.sh b/.ci/scripts/export_model_artifact.sh similarity index 65% rename from .ci/scripts/export_model_cuda_artifact.sh rename to .ci/scripts/export_model_artifact.sh index 85e34ae5b80..3c173b0ea2a 100755 --- a/.ci/scripts/export_model_cuda_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -5,19 +5,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Export model to CUDA format with optional quantization +# Export model to CUDA/Metal format with optional quantization show_help() { cat << EOF -Usage: export_model_cuda_artifact.sh [quant_name] [output_dir] +Usage: export_model_artifact.sh [quant_name] [output_dir] -Export a HuggingFace model to CUDA format with optional quantization. +Export a HuggingFace model to CUDA/Metal format with optional quantization. Arguments: + device cuda or metal (required) + hf_model HuggingFace model ID (required) Supported models: - mistralai/Voxtral-Mini-3B-2507 - - openai/whisper-small + - openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}) - google/gemma-3-4b-it quant_name Quantization type (optional, default: non-quantized) @@ -29,9 +31,9 @@ Arguments: output_dir Output directory for artifacts (optional, default: current directory) Examples: - export_model_cuda_artifact.sh "openai/whisper-small" - export_model_cuda_artifact.sh "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" - export_model_cuda_artifact.sh "google/gemma-3-4b-it" "non-quantized" "./output" + export_model_artifact.sh metal "openai/whisper-small" + export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" + export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output" EOF } @@ -48,9 +50,22 @@ fi set -eux -HF_MODEL="$1" -QUANT_NAME="${2:-non-quantized}" -OUTPUT_DIR="${3:-.}" +DEVICE="$1" +HF_MODEL="$2" +QUANT_NAME="${3:-non-quantized}" +OUTPUT_DIR="${4:-.}" + +case "$DEVICE" in + cuda) + ;; + metal) + ;; + *) + echo "Error: Unsupported device '$DEVICE'" + echo "Supported devices: cuda, metal" + exit 1 + ;; +esac # Determine model configuration based on HF model ID case "$HF_MODEL" in @@ -62,15 +77,23 @@ case "$HF_MODEL" in PREPROCESSOR_FEATURE_SIZE="128" PREPROCESSOR_OUTPUT="voxtral_preprocessor.pte" ;; - openai/whisper-small) + openai/whisper-*) MODEL_NAME="whisper" TASK="automatic-speech-recognition" MAX_SEQ_LEN="" EXTRA_PIP="librosa" - PREPROCESSOR_FEATURE_SIZE="80" PREPROCESSOR_OUTPUT="whisper_preprocessor.pte" + if [[ "$HF_MODEL" == *"large-v3"* ]]; then + PREPROCESSOR_FEATURE_SIZE="128" + else + PREPROCESSOR_FEATURE_SIZE="80" + fi ;; google/gemma-3-4b-it) + if [ "$DEVICE" = "metal" ]; then + echo "Error: Export for device 'metal' is not yet tested for model '$HF_MODEL'" + exit 1 + fi MODEL_NAME="gemma3" TASK="multimodal-text-to-text" MAX_SEQ_LEN="64" @@ -80,7 +103,7 @@ case "$HF_MODEL" in ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, google/gemma-3-4b-it" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it" exit 1 ;; esac @@ -91,9 +114,17 @@ case "$QUANT_NAME" in EXTRA_ARGS="" ;; quantized-int4-tile-packed) + if [ "$DEVICE" = "metal" ]; then + echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + exit 1 + fi EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d" ;; quantized-int4-weight-only) + if [ "$DEVICE" = "metal" ]; then + echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + exit 1 + fi EXTRA_ARGS="--qlinear_encoder 4w" ;; *) @@ -114,12 +145,18 @@ MAX_SEQ_LEN_ARG="" if [ -n "$MAX_SEQ_LEN" ]; then MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN" fi + +DEVICE_ARG="" +if [ "$DEVICE" = "cuda" ]; then + DEVICE_ARG="--device cuda" +fi + optimum-cli export executorch \ --model "$HF_MODEL" \ --task "$TASK" \ - --recipe "cuda" \ + --recipe "$DEVICE" \ --dtype bfloat16 \ - --device cuda \ + ${DEVICE_ARG} \ ${MAX_SEQ_LEN_ARG} \ ${EXTRA_ARGS} \ --output_dir ./ @@ -133,7 +170,7 @@ if [ -n "$PREPROCESSOR_OUTPUT" ]; then fi test -f model.pte -test -f aoti_cuda_blob.ptd +test -f aoti_${DEVICE}_blob.ptd if [ -n "$PREPROCESSOR_OUTPUT" ]; then test -f $PREPROCESSOR_OUTPUT fi @@ -141,10 +178,10 @@ echo "::endgroup::" echo "::group::Store $MODEL_NAME Artifacts" mkdir -p "${OUTPUT_DIR}" -cp model.pte "${OUTPUT_DIR}/" -cp aoti_cuda_blob.ptd "${OUTPUT_DIR}/" +mv model.pte "${OUTPUT_DIR}/" +mv aoti_${DEVICE}_blob.ptd "${OUTPUT_DIR}/" if [ -n "$PREPROCESSOR_OUTPUT" ]; then - cp $PREPROCESSOR_OUTPUT "${OUTPUT_DIR}/" + mv $PREPROCESSOR_OUTPUT "${OUTPUT_DIR}/" fi ls -al "${OUTPUT_DIR}" echo "::endgroup::" diff --git a/.ci/scripts/test_backend.sh b/.ci/scripts/test_backend.sh index a48cc9ec41a..e959a2f074a 100755 --- a/.ci/scripts/test_backend.sh +++ b/.ci/scripts/test_backend.sh @@ -57,8 +57,13 @@ if [[ "$FLOW" == *vulkan* ]]; then fi if [[ "$FLOW" == *arm* ]]; then + # Setup ARM deps. - .ci/scripts/setup-arm-baremetal-tools.sh + if [[ "$FLOW" == *vgf* ]]; then + .ci/scripts/setup-arm-baremetal-tools.sh --enable-mlsdk-deps --install-mlsdk-deps-with-pip + else + .ci/scripts/setup-arm-baremetal-tools.sh + fi source examples/arm/ethos-u-scratch/setup_path.sh if [[ "$FLOW" == *ethos_u* ]]; then @@ -66,6 +71,11 @@ if [[ "$FLOW" == *arm* ]]; then backends/arm/scripts/build_executorch.sh backends/arm/test/setup_testing.sh fi + + if [[ "$FLOW" == *vgf* ]]; then + # Prepare a test runner binary for VKML runtime + backends/arm/test/setup_testing_vkml.sh + fi fi if [[ $IS_MACOS -eq 1 ]]; then diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh old mode 100644 new mode 100755 index d9e527e7c78..2be02460944 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -137,6 +137,53 @@ else QNN_SDK_ROOT="" fi +# Set dynamic max export times +PLATFORM="x86" +if [[ "$(uname)" == "Darwin" ]]; then + PLATFORM="macos" +elif [[ "$(uname -m)" == "aarch64" ]] || [[ "$(uname -m)" == "arm64" ]]; then + PLATFORM="arm64" +fi + +BUFFER_TIME=25 + +# Lookup threshold based on platform:dtype:mode +case "${PLATFORM}:${DTYPE}:${MODE}:${PT2E_QUANTIZE}" in + + # Linux x86 configurations + "x86:fp32:portable:") ACT_EXPORT_TIME=72 ;; + "x86:fp32:xnnpack+custom:") ACT_EXPORT_TIME=276 ;; + "x86:bf16:portable:") ACT_EXPORT_TIME=75 ;; + "x86:bf16:custom:") ACT_EXPORT_TIME=65 ;; + "x86:fp32:xnnpack+custom+qe:") ACT_EXPORT_TIME=285 ;; + "x86:fp32:xnnpack+custom+quantize_kv:") ACT_EXPORT_TIME=295 ;; + "x86:fp32:xnnpack+quantize_kv:") ACT_EXPORT_TIME=356 ;; + "x86:fp32:qnn:16a16w") ACT_EXPORT_TIME=334 ;; + "x86:fp32:qnn:8a8w") ACT_EXPORT_TIME=81 ;; + + # Linux ARM64 configurations + "arm64:fp32:portable:") ACT_EXPORT_TIME=124 ;; + "arm64:fp32:xnnpack+custom:") ACT_EXPORT_TIME=483 ;; + "arm64:bf16:portable:") ACT_EXPORT_TIME=118 ;; + "arm64:bf16:custom:") ACT_EXPORT_TIME=102 ;; + "arm64:fp32:xnnpack+custom+qe:") ACT_EXPORT_TIME=486 ;; + "arm64:fp32:xnnpack+custom+quantize_kv:") ACT_EXPORT_TIME=521 ;; + "arm64:fp32:xnnpack+quantize_kv:") ACT_EXPORT_TIME=514 ;; + + # macOS configurations + "macos:fp32:mps:") ACT_EXPORT_TIME=30 ;; + "macos:fp32:coreml:") ACT_EXPORT_TIME=61 ;; + "macos:fp32:xnnpack+custom+quantize_kv:") ACT_EXPORT_TIME=133 ;; + + # Default fallback for unknown configurations + *) + ACT_EXPORT_TIME=450 + echo "Warning: No threshold defined for ${PLATFORM}:${DTYPE}:${MODE}:${PT2E_QUANTIZE}, using default: $((ACT_EXPORT_TIME + BUFFER_TIME))s" + ;; +esac + +MAX_EXPORT_TIME=$((ACT_EXPORT_TIME + BUFFER_TIME)) + echo "QNN option ${QNN}" echo "QNN_SDK_ROOT: ${QNN_SDK_ROOT}" @@ -171,15 +218,14 @@ cmake_build_llama_runner() { git submodule update --init popd dir="examples/models/llama" - retry cmake \ - -DEXECUTORCH_BUILD_TESTS=ON \ - -DBUILD_TESTING=OFF \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \ - -Bcmake-out/${dir} \ - ${dir} - cmake --build cmake-out/${dir} -j9 --config "$CMAKE_BUILD_TYPE" - + if [[ "$CMAKE_BUILD_TYPE" == "Debug" ]]; then + PRESET="llama-debug" + else + PRESET="llama-release" + fi + pushd "${dir}" + cmake --workflow --preset "${PRESET}" + popd } cleanup_files() { @@ -255,9 +301,24 @@ fi if [[ "${QUANTIZE_KV_CACHE}" == "ON" ]]; then EXPORT_ARGS="${EXPORT_ARGS} model.quantize_kv_cache=true" fi + +EXPORT_START_TIME=$(date +%s) + # Add dynamically linked library location $PYTHON_EXECUTABLE -m extension.llm.export.export_llm ${EXPORT_ARGS} +EXPORT_END_TIME=$(date +%s) +EXPORT_DURATION=$((EXPORT_END_TIME - EXPORT_START_TIME)) +echo "Model export completed at $(date +"%Y-%m-%d %H:%M:%S") - Duration: ${EXPORT_DURATION} seconds" + +# Check export time against threshold. Default is 500 seconds. +if [ $EXPORT_DURATION -gt $MAX_EXPORT_TIME ]; then + echo "Failure: Export took ${EXPORT_DURATION}s (threshold: ${MAX_EXPORT_TIME}s). This PR may have regressed export time — review changes or bump the threshold if appropriate." +fi + +echo "Success; Export time check passed: ${EXPORT_DURATION}s <= ${MAX_EXPORT_TIME}s" + + # Create tokenizer.bin. echo "Creating tokenizer.bin" $PYTHON_EXECUTABLE -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin diff --git a/.ci/scripts/test_llama_lora.sh b/.ci/scripts/test_llama_lora.sh index 73efe096f8f..fbcb50b5895 100644 --- a/.ci/scripts/test_llama_lora.sh +++ b/.ci/scripts/test_llama_lora.sh @@ -12,10 +12,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" cmake_install_executorch_libraries() { echo "Installing libexecutorch.a, libextension_module.so, libportable_ops_lib.a" rm -rf cmake-out - retry cmake --preset llm \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release - cmake --build cmake-out -j9 --target install --config Release + cmake --workflow llm-release } cmake_build_llama_runner() { diff --git a/.ci/scripts/test_model_cuda_e2e.sh b/.ci/scripts/test_model_e2e.sh similarity index 77% rename from .ci/scripts/test_model_cuda_e2e.sh rename to .ci/scripts/test_model_e2e.sh index 02845bf4b96..e26a843733f 100755 --- a/.ci/scripts/test_model_cuda_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -5,19 +5,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Test CUDA model end-to-end, need to run .ci/scripts/export_model_cuda_artifact.sh first +# Test CUDA/Metal model end-to-end, need to run .ci/scripts/export_model_artifact.sh first show_help() { cat << EOF -Usage: test_model_cuda_e2e.sh [model_dir] +Usage: test_model_e2e.sh [model_dir] -Build and run end-to-end tests for CUDA models. +Build and run end-to-end tests for CUDA/Metal models. Arguments: + device cuda or metal (required) + hf_model HuggingFace model ID (required) Supported models: - mistralai/Voxtral-Mini-3B-2507 - - openai/whisper-small + - openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}) - google/gemma-3-4b-it quant_name Quantization type (required) @@ -27,12 +29,12 @@ Arguments: - quantized-int4-weight-only model_dir Directory containing model artifacts (optional, default: current directory) - Expected files: model.pte, aoti_cuda_blob.ptd + Expected files: model.pte, aoti_cuda_blob.ptd/aoti_metal_blob.ptd Tokenizers and test files will be downloaded to this directory Examples: - test_model_cuda_e2e.sh "openai/whisper-small" "non-quantized" - test_model_cuda_e2e.sh "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output" + test_model_e2e.sh metal "openai/whisper-small" "non-quantized" + test_model_e2e.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output" EOF } @@ -55,20 +57,21 @@ fi set -eux -HF_MODEL="$1" -QUANT_NAME="$2" +DEVICE="$1" +HF_MODEL="$2" +QUANT_NAME="$3" # Download tokenizers, audio, and image files to this directory -MODEL_DIR="${3:-.}" +MODEL_DIR="${4:-.}" echo "Testing model: $HF_MODEL (quantization: $QUANT_NAME)" -# Make sure model.pte and aoti_cuda_blob.ptd exist +# Make sure model.pte and aoti_${DEVICE}_blob.ptd exist if [ ! -f "$MODEL_DIR/model.pte" ]; then echo "Error: model.pte not found in $MODEL_DIR" exit 1 fi -if [ ! -f "$MODEL_DIR/aoti_cuda_blob.ptd" ]; then - echo "Error: aoti_cuda_blob.ptd not found in $MODEL_DIR" +if [ ! -f "$MODEL_DIR/aoti_${DEVICE}_blob.ptd" ]; then + echo "Error: aoti_${DEVICE}_blob.ptd not found in $MODEL_DIR" exit 1 fi # Locate EXECUTORCH_ROOT from the directory of this script @@ -91,13 +94,13 @@ case "$HF_MODEL" in AUDIO_FILE="poem.wav" IMAGE_PATH="" ;; - openai/whisper-small) - MODEL_NAME="whisper" + openai/whisper-*) + MODEL_NAME="${HF_MODEL#openai/}" RUNNER_TARGET="whisper_runner" RUNNER_PATH="whisper" EXPECTED_OUTPUT="Mr. Quilter is the apostle of the middle classes" PREPROCESSOR="whisper_preprocessor.pte" - TOKENIZER_URL="https://huggingface.co/openai/whisper-small/resolve/main" # @lint-ignore + TOKENIZER_URL="https://huggingface.co/${HF_MODEL}/resolve/main" # @lint-ignore TOKENIZER_FILE="" AUDIO_URL="" AUDIO_FILE="output.wav" @@ -117,7 +120,7 @@ case "$HF_MODEL" in ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, google/gemma-3-4b-it" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it" exit 1 ;; esac @@ -142,7 +145,7 @@ fi # Download test files if [ "$AUDIO_URL" != "" ]; then curl -L $AUDIO_URL -o ${MODEL_DIR}/$AUDIO_FILE -elif [ "$MODEL_NAME" = "whisper" ]; then +elif [[ "$MODEL_NAME" == *whisper* ]]; then conda install -y -c conda-forge "ffmpeg<8" pip install datasets soundfile torchcodec python -c "from datasets import load_dataset;import soundfile as sf;sample = load_dataset('distil-whisper/librispeech_long', 'clean', split='validation')[0]['audio'];sf.write('${MODEL_DIR}/$AUDIO_FILE', sample['array'][:sample['sampling_rate']*30], sample['sampling_rate'])" @@ -152,34 +155,32 @@ ls -al echo "::endgroup::" echo "::group::Build $MODEL_NAME Runner" -cmake --preset llm \ - -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out -S. -cmake --build cmake-out -j$(nproc) --target install --config Release - -cmake -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -Sexamples/models/$RUNNER_PATH \ - -Bcmake-out/examples/models/$RUNNER_PATH/ -cmake --build cmake-out/examples/models/$RUNNER_PATH --target $RUNNER_TARGET --config Release + +if [ "$DEVICE" != "cuda" ] && [ "$DEVICE" != "metal" ]; then + echo "Error: Unsupported device '$DEVICE'. Must be 'cuda' or 'metal'." + exit 1 +fi + +MAKE_TARGET="${RUNNER_PATH}-${DEVICE}" +make "${MAKE_TARGET}" echo "::endgroup::" echo "::group::Run $MODEL_NAME Runner" set +e -export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH +if [ "$DEVICE" = "cuda" ]; then + export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH +fi # Build runner command with common arguments RUNNER_BIN="cmake-out/examples/models/$RUNNER_PATH/$RUNNER_TARGET" -RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd --temperature 0" +RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --data_path ${MODEL_DIR}/aoti_${DEVICE}_blob.ptd --temperature 0" # Add model-specific arguments case "$MODEL_NAME" in voxtral) RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR" ;; - whisper) + whisper-*) RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR" ;; gemma3) diff --git a/.ci/scripts/test_phi_3_mini.sh b/.ci/scripts/test_phi_3_mini.sh index 24ba4e0a1b5..086822bbad4 100644 --- a/.ci/scripts/test_phi_3_mini.sh +++ b/.ci/scripts/test_phi_3_mini.sh @@ -23,8 +23,16 @@ if hash nproc &> /dev/null; then NPROC=$(nproc); fi cmake_install_executorch_libraries() { rm -rf cmake-out - cmake --preset llm -DCMAKE_INSTALL_PREFIX=cmake-out -DCMAKE_BUILD_TYPE=${BUILD_TYPE} - cmake --build cmake-out -j16 --target install --config ${BUILD_TYPE} + + # Select workflow preset based on BUILD_TYPE + if [[ "${BUILD_TYPE}" == "Debug" ]]; then + WORKFLOW_PRESET="llm-debug" + else + WORKFLOW_PRESET="llm-release" + fi + + echo "Using workflow preset: ${WORKFLOW_PRESET}" + cmake --workflow --preset ${WORKFLOW_PRESET} } cmake_build_phi_3_mini() { diff --git a/.ci/scripts/test_qnn_static_llm.sh b/.ci/scripts/test_qnn_static_llm.sh index 9d1c82f12d5..6b105d1c6f2 100644 --- a/.ci/scripts/test_qnn_static_llm.sh +++ b/.ci/scripts/test_qnn_static_llm.sh @@ -81,7 +81,7 @@ elif [[ "${TASK_NAME}" == "stories_260k_bc" ]]; then fi elif [[ "${TASK_NAME}" == "smollm2_135m" ]]; then - $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_smollm2 --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64 + $PYTHON_EXECUTABLE backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_static_llm_model --model_name smollm2_135m --model SM8650 --build_folder build-x86/ --executorch_root . --artifact_dir ./static_smollm2 --enable_x86_64 exit_code1=$? if [ $exit_code1 -ne 0 ]; then exit 1 diff --git a/.ci/scripts/utils.sh b/.ci/scripts/utils.sh index 8f48e75e712..7fb7517e771 100644 --- a/.ci/scripts/utils.sh +++ b/.ci/scripts/utils.sh @@ -84,8 +84,8 @@ dedupe_macos_loader_path_rpaths() { install_domains() { echo "Install torchvision and torchaudio" - pip install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}" - pip install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}" + pip install --no-build-isolation --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}" + pip install --no-build-isolation --user "git+https://github.com/pytorch/vision.git@${TORCHVISION_VERSION}" } install_pytorch_and_domains() { diff --git a/.github/scripts/trigger_cuda_perf.sh b/.github/scripts/trigger_cuda_perf.sh new file mode 100755 index 00000000000..402dd009673 --- /dev/null +++ b/.github/scripts/trigger_cuda_perf.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Quick script to trigger cuda-perf workflow via GitHub CLI +# Usage: +# ./trigger_cuda_perf.sh # Use defaults (random model + quant) +# ./trigger_cuda_perf.sh --all # Run ALL models with ALL quantizations +# ./trigger_cuda_perf.sh "openai/whisper-medium" # Single model +# ./trigger_cuda_perf.sh "openai/whisper-small,google/gemma-3-4b-it" "non-quantized,quantized-int4-tile-packed" "100" + +set -e + +# All available models and quantizations +ALL_MODELS="mistralai/Voxtral-Mini-3B-2507,openai/whisper-small,openai/whisper-medium,openai/whisper-large-v3-turbo,google/gemma-3-4b-it" +ALL_QUANTIZATIONS="non-quantized,quantized-int4-tile-packed,quantized-int4-weight-only" + +# Check if gh CLI is installed +if ! command -v gh &> /dev/null; then + echo "Error: GitHub CLI (gh) is not installed." + echo "Install it from: https://cli.github.com/" + echo "" + echo "Quick install:" + echo " macOS: brew install gh" + echo " Linux: See https://github.com/cli/cli/blob/trunk/docs/install_linux.md" + exit 1 +fi + +# Check for --all flag +RUN_ALL=false +if [ "${1:-}" = "--all" ] || [ "${1:-}" = "-a" ]; then + RUN_ALL=true + shift # Remove the flag from arguments +fi + +# Default parameters +if [ "$RUN_ALL" = true ]; then + MODELS="$ALL_MODELS" + QUANT="$ALL_QUANTIZATIONS" + NUM_RUNS="${1:-50}" + RANDOM_MODEL="false" + echo "=========================================" + echo "Triggering cuda-perf workflow" + echo "Mode: RUN ALL MODELS AND QUANTIZATIONS" + echo "=========================================" + echo "Models: ALL (5 models)" + echo "Quantizations: ALL (3 quantizations)" + echo "Total configs: 15 combinations" + echo "Num runs: $NUM_RUNS" + echo "=========================================" +else + MODELS="${1:-}" + QUANT="${2:-}" + NUM_RUNS="${3:-50}" + RANDOM_MODEL="${4:-false}" + + # Display configuration + echo "=========================================" + echo "Triggering cuda-perf workflow" + echo "=========================================" + if [ -z "$MODELS" ]; then + echo "Models: (random selection)" + else + echo "Models: $MODELS" + fi + if [ -z "$QUANT" ]; then + echo "Quantizations: (random selection)" + else + echo "Quantizations: $QUANT" + fi + echo "Num runs: $NUM_RUNS" + echo "Random model: $RANDOM_MODEL" + echo "=========================================" +fi + +echo "" + +# Trigger workflow +gh workflow run cuda-perf.yml \ + -R pytorch/executorch \ + -f models="$MODELS" \ + -f quantizations="$QUANT" \ + -f num_runs="$NUM_RUNS" \ + -f random_model="$RANDOM_MODEL" + +if [ $? -eq 0 ]; then + echo "✓ Workflow triggered successfully!" + echo "" + echo "View status:" + echo " gh run list --workflow=cuda-perf.yml" + echo "" + echo "Watch the latest run:" + echo " gh run watch \$(gh run list --workflow=cuda-perf.yml --limit 1 --json databaseId --jq '.[0].databaseId')" +else + echo "✗ Failed to trigger workflow" + exit 1 +fi diff --git a/.github/workflows/add-unanswered-to-project.yml b/.github/workflows/add-unanswered-to-project.yml index 8b8114d0c04..5321d0f75e2 100644 --- a/.github/workflows/add-unanswered-to-project.yml +++ b/.github/workflows/add-unanswered-to-project.yml @@ -20,31 +20,32 @@ jobs: // List of authors to exclude const excludedAuthors = new Set([ - "nil-is-all", "cbilgin", "kimishpatel", "psiddh", "digantdesai", "SS-JIA", "ahmtox", "mcr229", "shoumikhin", - "manuelcandales", "metascroy", "cccclai", "rohansjoshi", "kirklandsign", "abhinaykukkadapu", "JacobSzwejbka", - "Conarnar", "lucylq", "larryliu0820", "BujSet", "Gasoonjia", "Juntian777", "guangy10", "jackzhxng", - "GregoryComer", "leafs1", "swolchok", "mergennachin", "tarun292", "byjlw", "jathu", "Jack-Khuu", "georgehong", + "nil-is-all", "tanvirislam-meta", "cbilgin", "kimishpatel", "psiddh", "digantdesai", "SS-JIA", "ahmtox", "mcr229", + "shoumikhin", "manuelcandales", "metascroy", "cccclai", "rohansjoshi", "kirklandsign", "abhinaykukkadapu", + "JacobSzwejbka", "Conarnar", "lucylq", "larryliu0820", "BujSet", "Gasoonjia", "Juntian777", "guangy10", "jackzhxng", + "GregoryComer", "leafs1", "swolchok", "mergennachin", "tarun292", "byjlw", "jathu", "Jack-Khuu", "georgehong", "zhenyan-zhang-meta", "silverguo", "harishs88ss", "AlannaBurke", "dbort", "huydhn", "mcremon-meta", "trivedivivek", - "angelayi", "helunwencser", "hsharma35", "zhxchen17", "iseeyuan", "svekars", "nathanaelsee", "dulinriley", "jerryzh168", - "cmodi-meta", "bigfootjon", "sxu", "ydwu4", "Riandy", "tugsbayasgalan", "bsoyluoglu", "yangw-dev", "YIWENX14", - "namanahuja", "yushangdi", "limintang", "pianpwk", "viveknayakatmeta", "andreanicastro", "JakeStevens", + "angelayi", "helunwencser", "hsharma35", "zhxchen17", "iseeyuan", "svekars", "nathanaelsee", "dulinriley", + "jerryzh168", "cmodi-meta", "bigfootjon", "sxu", "ydwu4", "Riandy", "tugsbayasgalan", "bsoyluoglu", "yangw-dev", + "YIWENX14", "namanahuja", "yushangdi", "limintang", "pianpwk", "viveknayakatmeta", "andreanicastro", "JakeStevens", "gmagogsfm", "zonglinpeng", "eigen-k", "derekxu", "salilsdesai", "skrtskrtfb", "pssrawat", "r-barnes", "kalpit-meta-1", "Will-MingLun-Li", "KapJI", "piyengar", "j-bahr", "BoyuanFeng", "fgasperij", "DariusHolmgren", "sammarden-meta", "kushrast", "meta-emilian", "Rittzz", "jeanschmidt", "copyrightly", "mikekgfb", "vmpuri", - "zonglinpengmeta", "maggiemoss", "aorenste", "hoangminhle98", "Solumin", "meyering", "rchen152", - "AishwaryaSivaraman", "migeed-z", "ebgraham", "Esteb37", "nausicaasnow", "Camyll", "ezyang", "huiyujie", - "dltn", "cjhopman", "blackm00n", "agunapal", "SamGondelman", "Ninja91", "ivayloen", "DrJessop", "rodrigos01meta", - "akrieger", "cmt0", "yiming0416", "ethansfng", "ThomasJannaud", "nirvanagth", "marcinkwiatkowski", "3l1", - "omerjerk", "nitish2112", "yipjustin", "ejnguyen", "andrewor14", "phaiting", "mgiordy", "LeeOHzzZ", "adicatana", - "Polyomino", "ezrilow", "navsud", "YifanShenSZ", "RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat", "azad-meta", - "pytorchbot", "pytorchmergebot", "pytorchupdatebot", "facebook-github-bot", "app/dependabot", "Erik-Lundell", - "zingo", "AdrianLundell", "oscarandersson8218", "per", "Sebastian-Larsson", "SaoirseARM", "robell", "mansnils", - "martinlsm", "freddan80", "YufengShi-dudu", "tom-arm", "perheld", "Jerry-Ge", "gggekov", "fumchin", "wwwind", + "zonglinpengmeta", "maggiemoss", "aorenste", "hoangminhle98", "Solumin", "meyering", "rchen152", "AishwaryaSivaraman", + "migeed-z", "ebgraham", "Esteb37", "nausicaasnow", "Camyll", "ezyang", "huiyujie", "dltn", "cjhopman", "blackm00n", + "agunapal", "SamGondelman", "Ninja91", "ivayloen", "DrJessop", "rodrigos01meta", "akrieger", "cmt0", "yiming0416", + "ethansfng", "ThomasJannaud", "nirvanagth", "marcinkwiatkowski", "3l1", "omerjerk", "nitish2112", "yipjustin", + "ejnguyen", "andrewor14", "phaiting", "mgiordy", "LeeOHzzZ", "adicatana", "Polyomino", "ezrilow", "navsud", + "michaelmaitland", "RahulC7", "seyeong-han", "YifanShenSZ", "RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat", + "azad-meta", "junpi", "pytorchbot", "pytorchmergebot", "pytorchupdatebot", "facebook-github-bot", "app/dependabot", + "Erik-Lundell", "zingo", "AdrianLundell", "oscarandersson8218", "per", "Sebastian-Larsson", "SaoirseARM", "robell", + "mansnils", "martinlsm", "freddan80", "YufengShi-dudu", "tom-arm", "perheld", "Jerry-Ge", "gggekov", "fumchin", "wwwind", "benkli01", "Tessil", "maddun01", "Michiel-Olieslagers", "armwaheed", "agrima1304", "emmakujala", "annietllnd", - "haowhsu-quic", "shewu-quic", "winskuo-quic", "chunit-quic", "DannyYuyang-quic", "chuntl", "thchenqti", - "jethroqti", "cymbalrush", "DenisVieriu97", "billmguo", "StrycekSimon", "jirioc", "robert-kalmar", "skywall", - "MartinPavella", "roman-janik-nxp", "novak-vaclav ", "neuropilot-captain", "dijopaul", "cad-rlc", "cad-audio", - "ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", "cavusmustafa", "Jiseong-oh", "alexdean08" + "MatthiasHertel80", "AlexTawseArm", "jmahbs", "haowhsu-quic", "shewu-quic", "winskuo-quic", "chunit-quic", + "DannyYuyang-quic", "chuntl", "thchenqti", "jethroqti", "chenweng-quic", "cymbalrush", "DenisVieriu97", "billmguo", + "StrycekSimon", "jirioc", "robert-kalmar", "skywall", "MartinPavella", "roman-janik-nxp", "novak-vaclav ", + "neuropilot-captain", "dijopaul", "cad-rlc", "cad-audio", "ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", + "cavusmustafa", "Jiseong-oh", "alexdean08" ]); async function addItem(contentId, type, number) { diff --git a/.github/workflows/android-perf-private-device-experiment.yml b/.github/workflows/android-perf-private-device-experiment.yml deleted file mode 100644 index cf37538f620..00000000000 --- a/.github/workflows/android-perf-private-device-experiment.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: android-perf (private devices) - -on: - schedule: - - cron: 0 0,4,8,12,16,20 * * * - pull_request: - paths: - - .github/workflows/android-perf-private-device-experiment.yml - push: - branches: - - main - paths: - - .github/workflows/android-perf-private-device-experiment.yml - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: android-perf-private-devices-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - android: - uses: ./.github/workflows/android-perf.yml - secrets: inherit - permissions: - id-token: write - contents: read - with: - models: ${{ inputs.models || github.event_name == 'schedule' && 'Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,meta-llama/Llama-3.2-1B,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'google/gemma-3-1b-it' }} - devices: samsung_galaxy_s22+private - benchmark_configs: ${{ inputs.benchmark_configs }} diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml deleted file mode 100644 index 33937531a01..00000000000 --- a/.github/workflows/android-perf.yml +++ /dev/null @@ -1,562 +0,0 @@ -name: android-perf - -on: - schedule: - - cron: 0 0,8,16 * * * - pull_request: - paths: - - .github/workflows/android-perf.yml - - .ci/scripts/gather_benchmark_configs.py - - extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 - push: - branches: - - main - paths: - - .github/workflows/android-perf.yml - - .ci/scripts/gather_benchmark_configs.py - - extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: samsung_galaxy_s22+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - set-parameters: - runs-on: ubuntu-22.04 - outputs: - benchmark_configs: ${{ steps.set-parameters.outputs.benchmark_configs }} - steps: - - uses: actions/checkout@v3 - with: - submodules: 'false' - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Set parameters - id: set-parameters - shell: bash - env: - # Separate default values from the workflow dispatch. To ensure defaults are accessible - # during scheduled runs and to provide flexibility for different defaults between - # on-demand and periodic benchmarking. - CRON_DEFAULT_MODELS: ${{ github.event_name == 'schedule' && 'mv3,mv2,ic4,ic3,resnet50,mobilebert,w2l,meta-llama/Llama-3.2-1B,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8,Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'Qwen/Qwen3-0.6B' }} - CRON_DEFAULT_DEVICES: samsung_galaxy_s22+public - run: | - set -eux - - ARGS="--os android" - - MODELS="${{ inputs.models }}" - if [ -z "$MODELS" ]; then - MODELS="$CRON_DEFAULT_MODELS" - fi - ARGS="$ARGS --models $MODELS" - - DEVICES="${{ inputs.devices }}" - if [ -z "$DEVICES" ]; then - DEVICES="$CRON_DEFAULT_DEVICES" - fi - ARGS="$ARGS --devices $DEVICES" - - BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}" - if [ -n "$BENCHMARK_CONFIGS" ]; then - ARGS="$ARGS --configs $BENCHMARK_CONFIGS" - fi - - PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS - - prepare-test-specs: - runs-on: linux.2xlarge - needs: set-parameters - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - steps: - - uses: actions/checkout@v3 - - - name: Prepare the spec - id: prepare - shell: bash - env: - BENCHMARK_CONFIG: ${{ toJSON(matrix) }} - working-directory: extension/benchmark/android/benchmark - run: | - set -eux - - # The model will be exported in the next step to this S3 path - MODEL_PATH="https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/model.zip" - # We could write a script to properly use jinja here, but there is only one variable, - # so let's just sed it - sed -i -e 's,{{ model_path }},'"${MODEL_PATH}"',g' android-llm-device-farm-test-spec.yml.j2 - - BENCHMARK_CONFIG_ID=$(echo "${{ matrix.model }}_${{ matrix.config }}" | sed -e 's/[^A-Za-z0-9._-]/_/g') - # The config for this benchmark runs, we save it in the test spec so that it can be fetched - # later by the upload script - sed -i -e 's,{{ benchmark_config_id }},'"${BENCHMARK_CONFIG_ID}"',g' android-llm-device-farm-test-spec.yml.j2 - - cp android-llm-device-farm-test-spec.yml.j2 android-llm-device-farm-test-spec.yml - # Just print the test spec for debugging - cat android-llm-device-farm-test-spec.yml - - # Save the benchmark configs so that we can use it later in the dashboard - echo "${BENCHMARK_CONFIG}" > "${BENCHMARK_CONFIG_ID}.json" - echo "benchmark-config-id=${BENCHMARK_CONFIG_ID}" >> $GITHUB_OUTPUT - - - name: Upload the spec - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }} - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml - - - name: Update the benchmark configs - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/android/benchmark/${{ steps.prepare.outputs.benchmark-config-id }}.json - - export-models: - name: export-models - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - needs: set-parameters - secrets: inherit - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - with: - runner: linux.2xlarge.memory - docker-image: ci-image:executorch-ubuntu-22.04-qnn-sdk - submodules: 'recursive' - timeout: 60 - upload-artifact: android-models - upload-artifact-to-s3: true - secrets-env: EXECUTORCH_HF_TOKEN - script: | - # The generic Linux job chooses to use base env, not the one setup by the image - echo "::group::Setting up dev environment" - CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") - conda activate "${CONDA_ENV}" - if [[ ${{ matrix.config }} == *"qnn"* ]]; then - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh - PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh - fi - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake" - # Install requirements for export_llama - PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh - - pip install -U "huggingface_hub[cli]" - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - pip install accelerate sentencepiece - pip list - - ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded/${{ matrix.model }}_${{ matrix.config }} - echo "::endgroup::" - - echo "::group::Exporting ${{ matrix.config }} model: ${{ matrix.model }}" - BUILD_MODE="cmake" - - if [[ ${{ matrix.model }} =~ ^[^/]+/[^/]+$ ]]; then - # HuggingFace model. Assume the pattern is always like "/" - HF_MODEL_REPO=${{ matrix.model }} - OUT_ET_MODEL_NAME="$(echo "$HF_MODEL_REPO" | awk -F'/' '{print $2}' | sed 's/_/-/g' | tr '[:upper:]' '[:lower:]')_${{ matrix.config }}" - - # Convert HF checkpoint to ET via etLLM path - if [[ "$HF_MODEL_REPO" == meta-llama/* ]]; then - if [[ ${{ matrix.config }} == "llama3_spinquant" ]]; then - # SpinQuant - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - model.use_kv_cache=true \ - model.dtype_override=fp32 \ - base.preq_embedding_quantize=\'8,0\' \ - quantization.use_spin_quant=native \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_qlora" ]]; then - # QAT + LoRA - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - quantization.use_qat=true \ - base.use_lora=16 \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - base.preq_embedding_quantize=\'8,0\' \ - model.use_sdpa_with_kv_cache=true \ - model.use_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - model.dtype_override=fp32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_fb16" ]]; then - # Original BF16 version, without any quantization - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - model.dtype_override=bf16 \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - python -m extension.llm.export.export_llm \ - base.model_class=llama3_2 \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_qnn_htp" ]]; then - export QNN_SDK_ROOT=/tmp/qnn/2.37.0.250724 - export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/ - export PYTHONPATH=$(pwd)/.. - - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - python -m examples.qualcomm.oss_scripts.llama3_2.llama -- \ - --checkpoint "${DOWNLOADED_PATH}/consolidated.00.pth" \ - --params "${DOWNLOADED_PATH}/params.json" \ - --tokenizer_model "${DOWNLOADED_PATH}/tokenizer.model" \ - --compile_only \ - --ptq 16a4w \ - -m SM8650 \ - --model_size 1B \ - --model_mode kv \ - --prompt "Once" - - OUT_ET_MODEL_NAME="llama3_2_qnn" # Qualcomm hard-coded it in their script - find . -name "${OUT_ET_MODEL_NAME}.pte" -not -path "./${OUT_ET_MODEL_NAME}.pte" -exec mv {} ./ \; - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - elif [[ "$HF_MODEL_REPO" == "Qwen/Qwen3-0.6B" ]]; then - if [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json") - python -m extension.llm.export.export_llm \ - base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/config/0_6b_config.json \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":151644,\"get_eos_ids\":[151645]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - fi - - if [[ ${{ matrix.config }} == "hf_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.json" - ) - echo "tokenizer.json is downloaded to $DOWNLOADED_PATH" - - # Install optimum-executorch - OPTIMUM_ET_COMMIT=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - git clone https://github.com/huggingface/optimum-executorch - pushd optimum-executorch - # There is no release yet, for CI stability, always test from the same commit on main - git checkout $OPTIMUM_ET_COMMIT - python install_dev.py --skip_override_torch - pip list - - ARGS=( - "--model" "${HF_MODEL_REPO}" - "--task" "text-generation" - "--recipe" "xnnpack" - "--use_custom_sdpa" - "--use_custom_kv_cache" - "--qlinear" "8da4w" - "--qembedding" "8w" - "--output_dir" ".." - ) - - optimum-cli export executorch "${ARGS[@]}" - popd - - mv model.pte ${OUT_ET_MODEL_NAME}.pte - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - - zip -j model.zip ${OUT_ET_MODEL_NAME}.pte ${DOWNLOADED_PATH}/tokenizer.* - ls -lh model.zip - mkdir -p ${ARTIFACTS_DIR_NAME} - mv model.zip ${ARTIFACTS_DIR_NAME} - ls -lh ${ARTIFACTS_DIR_NAME} - elif [[ ${{ matrix.model }} == "llama" ]]; then - # Install requirements for export_llama - PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh - # Test llama2 - if [[ ${{ matrix.config }} == *"xnnpack"* ]]; then - DELEGATE_CONFIG="xnnpack+custom+qe" - elif [[ ${{ matrix.config }} == *"qnn"* ]]; then - DELEGATE_CONFIG="qnn" - else - echo "Unsupported delegate ${{ matrix.config }}" - exit 1 - fi - DTYPE="fp32" - PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh \ - -model "${{ matrix.model }}" \ - -build_tool "${BUILD_MODE}" \ - -dtype "${DTYPE}" \ - -mode "${DELEGATE_CONFIG}" \ - -upload "${ARTIFACTS_DIR_NAME}" - else - PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh \ - "${{ matrix.model }}" \ - "${BUILD_MODE}" \ - "${{ matrix.config }}" \ - "${ARTIFACTS_DIR_NAME}" - fi - echo "::endgroup::" - - build-benchmark-app: - name: build-benchmark-app - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - needs: set-parameters - with: - runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-clang12-android - submodules: 'recursive' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - timeout: 90 - upload-artifact: android-apps - upload-artifact-to-s3: true - script: | - set -eux - - # Use sccache for NDK compiler as well - export CMAKE_CXX_COMPILER_LAUNCHER=sccache - export CMAKE_C_COMPILER_LAUNCHER=sccache - - # The generic Linux job chooses to use base env, not the one setup by the image - CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") - conda activate "${CONDA_ENV}" - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake - export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded - - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh - PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh - - mkdir -p aar-out - PYTHON_EXECUTABLE=python ANDROID_ABIS="arm64-v8a" BUILD_AAR_DIR=aar-out EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.37.0.250724 EXECUTORCH_ANDROID_PROFILING=ON bash scripts/build_android_library.sh - mkdir -p extension/benchmark/android/benchmark/app/libs - cp aar-out/executorch.aar extension/benchmark/android/benchmark/app/libs - pushd extension/benchmark/android/benchmark - ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build assembleAndroidTest - popd - MINIBENCH_APP_DIR="${ARTIFACTS_DIR_NAME}/minibench" - mkdir -p "${MINIBENCH_APP_DIR}" - cp extension/benchmark/android/benchmark/app/build/outputs/apk/debug/*.apk "${MINIBENCH_APP_DIR}" - cp extension/benchmark/android/benchmark/app/build/outputs/apk/androidTest/debug/*.apk "${MINIBENCH_APP_DIR}" - - # Let's see how expensive this job is, we might want to tone it down by running it periodically - # CHANGE IF this job name 'benchmark-on-device' changed: extract_model_info() in executorch/.github/scripts/extract_benchmark_results.py - benchmark-on-device: - if: always() - permissions: - id-token: write - contents: read - uses: pytorch/test-infra/.github/workflows/mobile_job.yml@main - needs: - - set-parameters - - prepare-test-specs - - build-benchmark-app - - export-models - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - with: - # Due to scheduling a job may be pushed beyond the default 60m threshold - timeout: 240 - device-type: android - runner: linux.2xlarge - test-infra-ref: '' - # This is the ARN of ExecuTorch project on AWS - project-arn: arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6 - device-pool-arn: ${{ matrix.device_arn }} - android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/minibench/app-debug.apk - android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/minibench/app-debug-androidTest.apk - test-spec: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/android-llm-device-farm-test-spec.yml - new-output-format-flag: true - - upload-benchmark-results: - needs: - - benchmark-on-device - if: always() - runs-on: linux.2xlarge - environment: upload-benchmark-results - permissions: - id-token: write - contents: read - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - - name: Authenticate with AWS - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results - # The max duration enforced by the server side - role-duration-seconds: 18000 - aws-region: us-east-1 - - - name: Setup conda - uses: pytorch/test-infra/.github/actions/setup-miniconda@main - with: - python-version: '3.10' - - - name: Download the list of artifacts from S3 - env: - ARTIFACTS_S3_DIR: s3://gha-artifacts/device_farm/${{ github.run_id }}/${{ github.run_attempt }}/artifacts/ - shell: bash - run: | - set -eux - ${CONDA_RUN} python -mpip install awscli==1.32.18 - - mkdir -p artifacts - pushd artifacts - ${CONDA_RUN} aws s3 sync "${ARTIFACTS_S3_DIR}" . - popd - - ls -lah artifacts - - - name: Download the list of benchmark configs from S3 - env: - BENCHMARK_CONFIGS_DIR: s3://gha-artifacts/${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - shell: bash - run: | - set -eux - - mkdir -p benchmark-configs - pushd benchmark-configs - ${CONDA_RUN} aws s3 sync "${BENCHMARK_CONFIGS_DIR}" . - popd - - ls -lah benchmark-configs - - - name: Extract the benchmark results JSON - shell: bash - env: - DEVICE_TYPE: android - run: | - set -eux - - mkdir -p benchmark-results - - for ARTIFACTS_BY_JOB in artifacts/*.json; do - [ -f "${ARTIFACTS_BY_JOB}" ] || break - echo "${ARTIFACTS_BY_JOB}" - ${CONDA_RUN} python .github/scripts/extract_benchmark_results.py \ - --artifacts "${ARTIFACTS_BY_JOB}" \ - --output-dir benchmark-results \ - --app "${DEVICE_TYPE}" \ - --benchmark-configs benchmark-configs - done - - for BENCHMARK_RESULTS in benchmark-results/v3/*.json; do - cat "${BENCHMARK_RESULTS}" - echo - done - - - name: Upload the benchmark results (v3) - uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main - with: - benchmark-results-dir: benchmark-results/v3 - dry-run: false - schema-version: v3 - github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/apple-perf-private-device-experiment.yml b/.github/workflows/apple-perf-private-device-experiment.yml deleted file mode 100644 index 47e2c6c9340..00000000000 --- a/.github/workflows/apple-perf-private-device-experiment.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: apple-perf (private devices) - -on: - schedule: - - cron: 0 0,4,8,12,16,20 * * * - pull_request: - paths: - - .github/workflows/apple-perf-private-device-experiment.yml - push: - branches: - - main - paths: - - .github/workflows/apple-perf-private-device-experiment.yml - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+pro_private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+pro_private - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: apple-perf-private-devices-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - apple: - uses: ./.github/workflows/apple-perf.yml - secrets: inherit - permissions: - id-token: write - contents: read - with: - models: ${{ inputs.models || github.event_name == 'schedule' && 'Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,meta-llama/Llama-3.2-1B,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'google/gemma-3-1b-it' }} - devices: apple_iphone_15+pro_private - benchmark_configs: ${{ inputs.benchmark_configs }} diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml deleted file mode 100644 index 56fc67d1617..00000000000 --- a/.github/workflows/apple-perf.yml +++ /dev/null @@ -1,603 +0,0 @@ -name: apple-perf - -on: - schedule: - - cron: 0 1 * * * - pull_request: - paths: - - .github/workflows/apple-perf.yml - - .ci/scripts/gather_benchmark_configs.py - - extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2 - push: - branches: - - main - paths: - - .github/workflows/apple-perf.yml - - .ci/scripts/gather_benchmark_configs.py - - extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml.j2 - # Note: GitHub has an upper limit of 10 inputs - workflow_dispatch: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - workflow_call: - inputs: - models: - description: Models to be benchmarked - required: false - type: string - default: Qwen/Qwen3-0.6B - devices: - description: Target devices to run benchmark - required: false - type: string - default: apple_iphone_15+public - benchmark_configs: - description: The list of configs used the benchmark - required: false - type: string - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - set-parameters: - runs-on: ubuntu-22.04 - outputs: - benchmark_configs: ${{ steps.set-parameters.outputs.benchmark_configs }} - steps: - - uses: actions/checkout@v3 - with: - submodules: 'false' - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Set parameters - id: set-parameters - shell: bash - env: - # Separate default values from the workflow dispatch. To ensure defaults are accessible - # during scheduled runs and to provide flexibility for different defaults between - # on-demand and periodic benchmarking. - CRON_DEFAULT_MODELS: ${{ github.event_name == 'schedule' && 'mv3,mv2,ic4,ic3,resnet50,edsr,mobilebert,w2l,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8,Qwen/Qwen3-0.6B,HuggingFaceTB/SmolLM2-135M,meta-llama/Llama-3.2-1B,allenai/OLMo-1B-hf,google/gemma-3-1b-it' || 'Qwen/Qwen3-0.6B' }} - CRON_DEFAULT_DEVICES: apple_iphone_15+public - run: | - set -eux - - ARGS="--os ios" - - MODELS="${{ inputs.models }}" - if [ -z "$MODELS" ]; then - MODELS="$CRON_DEFAULT_MODELS" - fi - ARGS="$ARGS --models $MODELS" - - DEVICES="${{ inputs.devices }}" - if [ -z "$DEVICES" ]; then - DEVICES="$CRON_DEFAULT_DEVICES" - fi - ARGS="$ARGS --devices $DEVICES" - - BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}" - if [ -n "$BENCHMARK_CONFIGS" ]; then - ARGS="$ARGS --configs $BENCHMARK_CONFIGS" - fi - - PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS - - echo "benchmark_configs is: ${{ steps.set-parameters.outputs.benchmark_configs }}" - - prepare-test-specs: - runs-on: linux.2xlarge - needs: set-parameters - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - steps: - - uses: actions/checkout@v3 - - - name: Prepare the spec - id: prepare - shell: bash - env: - BENCHMARK_CONFIG: ${{ toJSON(matrix) }} - working-directory: extension/benchmark/apple/Benchmark - run: | - set -eux - - # The model will be exported in the next step to this S3 path - MODEL_PATH="https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/model.zip" - # We could write a script to properly use jinja here, but there is only one variable, - # so let's just sed it - sed -i -e 's,{{ model_path }},'"${MODEL_PATH}"',g' default-ios-device-farm-appium-test-spec.yml.j2 - - BENCHMARK_CONFIG_ID=$(echo "${{ matrix.model }}_${{ matrix.config }}" | sed -e 's/[^A-Za-z0-9._-]/_/g') - # The config for this benchmark runs, we save it in the test spec so that it can be fetched - # later by the upload script - sed -i -e 's,{{ benchmark_config_id }},'"${BENCHMARK_CONFIG_ID}"',g' default-ios-device-farm-appium-test-spec.yml.j2 - - cp default-ios-device-farm-appium-test-spec.yml.j2 default-ios-device-farm-appium-test-spec.yml - # Just print the test spec for debugging - cat default-ios-device-farm-appium-test-spec.yml - - # Save the benchmark configs so that we can use it later in the dashboard - echo "${BENCHMARK_CONFIG}" > "${BENCHMARK_CONFIG_ID}.json" - echo "benchmark-config-id=${BENCHMARK_CONFIG_ID}" >> $GITHUB_OUTPUT - - - name: Upload the spec - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }} - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/apple/Benchmark/default-ios-device-farm-appium-test-spec.yml - - - name: Update the benchmark configs - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - retention-days: 1 - if-no-files-found: error - path: extension/benchmark/apple/Benchmark/${{ steps.prepare.outputs.benchmark-config-id }}.json - - export-models: - name: export-models - uses: pytorch/test-infra/.github/workflows/macos_job.yml@main - needs: set-parameters - secrets: inherit - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - with: - # NB: Need to use our AWS MacOS runner to upload large models to S3 - runner: macos-m1-stable - python-version: '3.11' - submodules: 'recursive' - timeout: 60 - upload-artifact: ios-models - upload-artifact-to-s3: true - secrets-env: EXECUTORCH_HF_TOKEN - script: | - set -eux - - echo "::group::Setting up CI environment" - .ci/scripts/setup-conda.sh - - BUILD_TOOL=cmake - # Setup MacOS dependencies as there is no Docker support on MacOS atm - GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - .ci/scripts/setup-macos.sh --build-tool "${BUILD_TOOL}" - - if [[ ${{ matrix.config }} == *"coreml"* ]]; then - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - backends/apple/coreml/scripts/install_requirements.sh - fi - - # Install requirements for export_llama - PYTHON_EXECUTABLE=python ${CONDA_RUN} bash examples/models/llama/install_requirements.sh - - pip install -U "huggingface_hub[cli]" - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - ${CONDA_RUN} pip install accelerate sentencepiece - pip list - - ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded/${{ matrix.model }}_${{ matrix.config }} - echo "::endgroup::" - - echo "::group::Exporting ${{ matrix.config }} model: ${{ matrix.model }}" - BUILD_MODE="cmake" - - if [[ ${{ matrix.model }} =~ ^[^/]+/[^/]+$ ]]; then - # HuggingFace model. Assume the pattern is always like "/" - HF_MODEL_REPO=${{ matrix.model }} - OUT_ET_MODEL_NAME="$(echo "$HF_MODEL_REPO" | awk -F'/' '{print $2}' | sed 's/_/-/g' | tr '[:upper:]' '[:lower:]')_${{ matrix.config }}" - - # Convert HF checkpoint to ET via etLLM path - if [[ "$HF_MODEL_REPO" == meta-llama/* ]]; then - # The benchmark app replies on the _llm suffix to determine whether the model is a LLM or not - OUT_ET_MODEL_NAME=${OUT_ET_MODEL_NAME}_llm - # Llama models on Hugging Face - if [[ ${{ matrix.config }} == "llama3_spinquant" ]]; then - # SpinQuant - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - model.use_kv_cache=true \ - model.dtype_override=fp32 \ - base.preq_embedding_quantize=\'8,0\' \ - quantization.use_spin_quant=native \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_qlora" ]]; then - # QAT + LoRA - # Download prequantized chceckpoint from Hugging Face - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.model" "params.json" "consolidated.00.pth" - ) - # Export using ExecuTorch's model definition - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - quantization.use_qat=true \ - base.use_lora=16 \ - base.preq_mode="preq_8da4w_out_8da8w" \ - base.preq_group_size=32 \ - base.preq_embedding_quantize=\'8,0\' \ - model.use_sdpa_with_kv_cache=true \ - model.use_kv_cache=true \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - model.dtype_override=fp32 \ - export.max_seq_length=2048 \ - export.max_context_length=2048 \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_fb16" ]]; then - # Original BF16 version, without any quantization - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class="llama3_2" \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - backend.xnnpack.enabled=true \ - model.dtype_override=bf16 \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class=llama3_2 \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":128000,\"get_eos_ids\":[128009,128001]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - elif [[ ${{ matrix.config }} == "llama3_coreml_ane" ]]; then - # ANE - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "original" --files "tokenizer.model" "params.json" "consolidated.00.pth") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ - base.params="${DOWNLOADED_PATH}/params.json" \ - quantization.embedding_quantize=\'4,32\' \ - model.use_kv_cache=true \ - model.enable_dynamic_shape=false \ - backend.coreml.enabled=true \ - backend.coreml.ios=18 \ - backend.coreml.quantize=c4w \ - backend.coreml.compute_units=cpu_and_ne \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - elif [[ "$HF_MODEL_REPO" == "Qwen/Qwen3-0.6B" ]]; then - OUT_ET_MODEL_NAME=${OUT_ET_MODEL_NAME}_llm - if [[ ${{ matrix.config }} == "et_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$(bash .ci/scripts/download_hf_hub.sh --model_id "${HF_MODEL_REPO}" --subdir "." --files "tokenizer.json") - ${CONDA_RUN} python -m extension.llm.export.export_llm \ - base.model_class=qwen3_0_6b \ - base.params=examples/models/qwen3/config/0_6b_config.json \ - model.use_kv_cache=true \ - model.use_sdpa_with_kv_cache=true \ - model.dtype_override=fp32 \ - backend.xnnpack.enabled=true \ - backend.xnnpack.extended_ops=true \ - quantization.qmode=8da4w \ - quantization.group_size=32 \ - quantization.embedding_quantize=\'8,0\' \ - base.metadata='"{\"get_bos_id\":151644,\"get_eos_ids\":[151645]}"' \ - export.output_name="${OUT_ET_MODEL_NAME}.pte" - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - fi - - if [[ ${{ matrix.config }} == "hf_xnnpack_custom_spda_kv_cache_8da4w" ]]; then - DOWNLOADED_PATH=$( - bash .ci/scripts/download_hf_hub.sh \ - --model_id "${HF_MODEL_REPO}" \ - --files "tokenizer.json" - ) - echo "tokenizer.json is downloaded to $DOWNLOADED_PATH" - - # Install optimum-executorch - OPTIMUM_ET_COMMIT=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - git clone https://github.com/huggingface/optimum-executorch - pushd optimum-executorch - # There is no release yet, for CI stability, always test from the same commit on main - git checkout $OPTIMUM_ET_COMMIT - ${CONDA_RUN} python install_dev.py --skip_override_torch - pip list - - ARGS=( - "--model" "${HF_MODEL_REPO}" - "--task" "text-generation" - "--recipe" "xnnpack" - "--use_custom_sdpa" - "--use_custom_kv_cache" - "--qlinear" "8da4w" - "--qembedding" "8w" - "--output_dir" ".." - ) - - ${CONDA_RUN} optimum-cli export executorch "${ARGS[@]}" - popd - - # The benchmark app replies on the _llm suffix to determine whether the model is a LLM or not - OUT_ET_MODEL_NAME=${OUT_ET_MODEL_NAME}_llm - mv model.pte ${OUT_ET_MODEL_NAME}.pte - ls -lh "${OUT_ET_MODEL_NAME}.pte" - fi - - zip -j model.zip ${OUT_ET_MODEL_NAME}.pte ${DOWNLOADED_PATH}/tokenizer.* - ls -lh model.zip - mkdir -p "${ARTIFACTS_DIR_NAME}" - mv model.zip "${ARTIFACTS_DIR_NAME}" - elif [[ ${{ matrix.model }} == "llama" ]]; then - # Install requirements for export_llama - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - bash examples/models/llama/install_requirements.sh - - # Test llama2 - if [[ ${{ matrix.config }} == *"xnnpack"* ]]; then - DELEGATE_CONFIG="xnnpack+custom+qe" - elif [[ ${{ matrix.config }} == *"coreml"* ]]; then - DELEGATE_CONFIG="coreml" - elif [[ ${{ matrix.config }} == *"mps"* ]]; then - DELEGATE_CONFIG="mps" - fi - DTYPE="fp32" - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - bash .ci/scripts/test_llama.sh \ - -model "stories110M" \ - -build_tool "${BUILD_MODE}" \ - -dtype "${DTYPE}" \ - -mode "${DELEGATE_CONFIG}" \ - -upload "${ARTIFACTS_DIR_NAME}" - else - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - bash .ci/scripts/test_model.sh \ - "${{ matrix.model }}" \ - "${BUILD_MODE}" \ - "${{ matrix.config }}" \ - "${ARTIFACTS_DIR_NAME}" - fi - echo "::endgroup::" - - build-benchmark-app: - name: build-benchmark-app - uses: pytorch/test-infra/.github/workflows/macos_job.yml@main - needs: - - set-parameters - secrets: inherit - with: - runner: macos-14-xlarge - python-version: '3.11' - submodules: 'recursive' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - upload-artifact: ios-apps - secrets-env: BUILD_CERTIFICATE_BASE64 EXECUTORCH_BENCHMARK_BUILD_PROVISION_PROFILE_BASE64 KEYCHAIN_PASSWORD - timeout: 90 - script: | - set -eux - - echo "::group::Setting up CI environment" - .ci/scripts/setup-conda.sh - - BUILD_TOOL=cmake - # Setup MacOS dependencies as there is no Docker support on MacOS atm - GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - .ci/scripts/setup-macos.sh --build-tool "${BUILD_TOOL}" - export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded - - # Setup Apple certificate for iOS development - BUILD_PROVISION_PROFILE_BASE64="${SECRET_EXECUTORCH_BENCHMARK_BUILD_PROVISION_PROFILE_BASE64}" \ - BUILD_CERTIFICATE_BASE64="${SECRET_BUILD_CERTIFICATE_BASE64}" \ - KEYCHAIN_PASSWORD="${SECRET_KEYCHAIN_PASSWORD}" \ - .ci/scripts/setup-ios.sh - - # Install CoreML Backend Requirements - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ - backends/apple/coreml/scripts/install_requirements.sh - echo "::endgroup::" - - echo "::group::Build ExecuTorch iOS frameworks" - PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output scripts/build_apple_frameworks.sh - echo "::endgroup::" - - # NB: Although exported models can be copied to this directory and bundled together with the - # app, we don't use this in CI and rely on AWS extra data parameter to make the model and the - # tokenizer available to the benchmark. This decouples the app and the model. We just need to - # create the directory here to pass the build - mkdir -p extension/benchmark/apple/Benchmark/Models - ${CONDA_RUN} --no-capture-output \ - scripts/build_apple_llm_demo.sh ${ARTIFACTS_DIR_NAME} - - upload-benchmark-app: - needs: build-benchmark-app - runs-on: linux.2xlarge - steps: - - name: Download the apps from GitHub - uses: actions/download-artifact@v4 - with: - # The name here needs to match the name of the upload-artifact parameter - name: ios-apps - path: ${{ runner.temp }}/artifacts/ - - - name: Verify the apps - shell: bash - working-directory: ${{ runner.temp }}/artifacts/ - run: | - ls -lah ./ - - - name: Upload the apps to S3 - uses: seemethere/upload-artifact-s3@v5 - with: - s3-bucket: gha-artifacts - s3-prefix: | - ${{ github.repository }}/${{ github.run_id }}/artifacts - retention-days: 14 - if-no-files-found: ignore - path: ${{ runner.temp }}/artifacts/ - - # CHANGE IF this job name 'benchmark-on-device' changed: extract_model_info() in executorch/.github/scripts/extract_benchmark_results.py - benchmark-on-device: - if: always() - needs: - - set-parameters - - prepare-test-specs - - upload-benchmark-app - - export-models - permissions: - id-token: write - contents: read - uses: pytorch/test-infra/.github/workflows/mobile_job.yml@main - strategy: - matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} - fail-fast: false - with: - # Due to scheduling a job may be pushed beyond the default 60m threshold - timeout: 120 - device-type: ios - # For iOS testing, the runner just needs to call AWS Device Farm, so there is no need to run this on macOS - runner: linux.2xlarge - test-infra-ref: '' - # This is the ARN of ExecuTorch project on AWS - project-arn: arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6 - device-pool-arn: ${{ matrix.device_arn }} - # Uploaded to S3 from the previous job - ios-ipa-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/Benchmark.ipa - ios-xctestrun-zip: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/Benchmark.xctestrun.zip - test-spec: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/${{ matrix.model }}_${{ matrix.config }}/default-ios-device-farm-appium-test-spec.yml - new-output-format-flag: true - - upload-benchmark-results: - needs: - - benchmark-on-device - if: always() - runs-on: linux.2xlarge - environment: upload-benchmark-results - permissions: - id-token: write - contents: read - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - - name: Authenticate with AWS - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results - # The max duration enforced by the server side - role-duration-seconds: 18000 - aws-region: us-east-1 - - - name: Setup conda - uses: pytorch/test-infra/.github/actions/setup-miniconda@main - with: - python-version: '3.10' - - - name: Download the list of artifacts from S3 - env: - ARTIFACTS_S3_DIR: s3://gha-artifacts/device_farm/${{ github.run_id }}/${{ github.run_attempt }}/artifacts/ - shell: bash - run: | - set -eux - ${CONDA_RUN} python -mpip install awscli==1.32.18 - - mkdir -p artifacts - pushd artifacts - ${CONDA_RUN} aws s3 sync "${ARTIFACTS_S3_DIR}" . - popd - - ls -lah artifacts - - - name: Download the list of benchmark configs from S3 - env: - BENCHMARK_CONFIGS_DIR: s3://gha-artifacts/${{ github.repository }}/${{ github.run_id }}/artifacts/benchmark-configs/ - shell: bash - run: | - set -eux - mkdir -p benchmark-configs - pushd benchmark-configs - ${CONDA_RUN} aws s3 sync "${BENCHMARK_CONFIGS_DIR}" . - popd - ls -lah benchmark-configs - - - name: Extract the benchmark results JSON - shell: bash - env: - DEVICE_TYPE: ios - run: | - set -eux - - mkdir -p benchmark-results - - for ARTIFACTS_BY_JOB in artifacts/*.json; do - [ -f "${ARTIFACTS_BY_JOB}" ] || break - echo "${ARTIFACTS_BY_JOB}" - ${CONDA_RUN} python .github/scripts/extract_benchmark_results.py \ - --artifacts "${ARTIFACTS_BY_JOB}" \ - --output-dir benchmark-results \ - --app "${DEVICE_TYPE}" \ - --benchmark-configs benchmark-configs - done - - for BENCHMARK_RESULTS in benchmark-results/v3/*.json; do - cat "${BENCHMARK_RESULTS}" - echo - done - - - name: Upload the benchmark results (v3) - uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main - with: - benchmark-results-dir: benchmark-results/v3 - dry-run: false - schema-version: v3 - github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/cuda-perf.yml b/.github/workflows/cuda-perf.yml new file mode 100644 index 00000000000..71e3adf5abc --- /dev/null +++ b/.github/workflows/cuda-perf.yml @@ -0,0 +1,439 @@ +name: cuda-perf + +on: + schedule: + - cron: 0 8 * * * # 12am / 1am PST (8am UTC) + pull_request: + paths: + - .github/workflows/cuda-perf.yml + - .ci/scripts/cuda_benchmark.py + - .ci/scripts/export_model_artifact.sh + - .ci/scripts/test_model_e2e.sh + push: + branches: + - main + paths: + - .github/workflows/cuda-perf.yml + - .ci/scripts/cuda_benchmark.py + - .ci/scripts/export_model_artifact.sh + - .ci/scripts/test_model_e2e.sh + workflow_dispatch: + inputs: + models: + description: Models to be benchmarked (comma-separated HuggingFace model IDs) + required: false + type: string + default: openai/whisper-small + quantizations: + description: Quantization types (comma-separated) + required: false + type: string + default: non-quantized + num_runs: + description: Number of benchmark runs per model + required: false + type: string + default: "50" + run_all_models: + description: Run all available models (overrides models input) + required: false + type: boolean + default: false + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + set-parameters: + runs-on: ubuntu-22.04 + outputs: + benchmark_configs: ${{ steps.set-parameters.outputs.benchmark_configs }} + steps: + - uses: actions/checkout@v3 + with: + submodules: 'false' + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Set parameters + id: set-parameters + shell: bash + env: + # All available models and quantizations + ALL_MODELS: 'mistralai/Voxtral-Mini-3B-2507,openai/whisper-small,openai/whisper-medium,openai/whisper-large-v3-turbo,google/gemma-3-4b-it' + ALL_QUANTIZATIONS: 'non-quantized,quantized-int4-tile-packed,quantized-int4-weight-only' + NUM_RUNS: ${{ inputs.num_runs || '50' }} + RUN_ALL_MODELS: ${{ inputs.run_all_models || 'false' }} + RANDOM_MODEL: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' && 'true' || 'false' }} + run: | + set -eux + + MODELS="${{ inputs.models }}" + QUANTIZATIONS="${{ inputs.quantizations }}" + + # If run_all_models is true, use all models + if [ "$RUN_ALL_MODELS" = "true" ]; then + MODELS="$ALL_MODELS" + echo "Running all available models: $MODELS" + # For non-schedule events (PR, manual trigger without inputs), randomly select one model and one quantization + elif [ -z "$MODELS" ] && [ "${{ github.event_name }}" != "schedule" ]; then + # Split all models into array + IFS=',' read -ra ALL_MODEL_ARRAY <<< "$ALL_MODELS" + # Randomly select one model + RANDOM_MODEL_INDEX=$((RANDOM % ${#ALL_MODEL_ARRAY[@]})) + MODELS="${ALL_MODEL_ARRAY[$RANDOM_MODEL_INDEX]}" + echo "Randomly selected model for PR/push: $MODELS" + elif [ -z "$MODELS" ]; then + # Schedule event: use all models + MODELS="$ALL_MODELS" + fi + + # If run_all_models is true, use all quantizations + if [ "$RUN_ALL_MODELS" = "true" ]; then + QUANTIZATIONS="$ALL_QUANTIZATIONS" + echo "Running all available quantizations: $QUANTIZATIONS" + elif [ -z "$QUANTIZATIONS" ] && [ "${{ github.event_name }}" != "schedule" ]; then + # Split all quantizations into array + IFS=',' read -ra ALL_QUANT_ARRAY <<< "$ALL_QUANTIZATIONS" + # Randomly select one quantization + RANDOM_QUANT_INDEX=$((RANDOM % ${#ALL_QUANT_ARRAY[@]})) + QUANTIZATIONS="${ALL_QUANT_ARRAY[$RANDOM_QUANT_INDEX]}" + echo "Randomly selected quantization for PR/push: $QUANTIZATIONS" + elif [ -z "$QUANTIZATIONS" ]; then + # Schedule event: use all quantizations + QUANTIZATIONS="$ALL_QUANTIZATIONS" + fi + + # Split models and quantizations into arrays + IFS=',' read -ra MODEL_ARRAY <<< "$MODELS" + IFS=',' read -ra QUANT_ARRAY <<< "$QUANTIZATIONS" + + # If random model is requested (for main branch push), select one random model from the already selected models + if [ "$RANDOM_MODEL" = "true" ] && [ ${#MODEL_ARRAY[@]} -gt 1 ]; then + RANDOM_INDEX=$((RANDOM % ${#MODEL_ARRAY[@]})) + MODELS="${MODEL_ARRAY[$RANDOM_INDEX]}" + MODEL_ARRAY=("$MODELS") + echo "Random model selected for main branch push: $MODELS" + fi + + # Generate benchmark configs + CONFIGS='{"include":[' + FIRST=true + for MODEL in "${MODEL_ARRAY[@]}"; do + for QUANT in "${QUANT_ARRAY[@]}"; do + if [ "$FIRST" = true ]; then + FIRST=false + else + CONFIGS+=',' + fi + # Sanitize model name for use in artifact paths + MODEL_SAFE=$(echo "$MODEL" | sed 's/\//_/g') + CONFIGS+="{\"model\":\"$MODEL\",\"quant\":\"$QUANT\",\"model_safe\":\"$MODEL_SAFE\",\"num_runs\":\"$NUM_RUNS\"}" + done + done + CONFIGS+=']}' + + echo "benchmark_configs=$CONFIGS" >> $GITHUB_OUTPUT + echo "Generated benchmark configs:" + echo "$CONFIGS" | python -m json.tool + + export-models: + name: export-models + needs: set-parameters + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + secrets: inherit + strategy: + matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} + fail-fast: false + with: + timeout: 90 + secrets-env: EXECUTORCH_HF_TOKEN + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + use-custom-docker-registry: false + submodules: recursive + upload-artifact: model-${{ matrix.model_safe }}-${{ matrix.quant }} + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + echo "::group::Setup ExecuTorch" + ./install_executorch.sh + echo "::endgroup::" + + echo "::group::Setup Huggingface" + pip install -U "huggingface_hub[cli]<1.0" accelerate + huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN + OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) + pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} + echo "::endgroup::" + + echo "::group::Exporting model ${{ matrix.model }} with quantization ${{ matrix.quant }}" + OUTPUT_DIR="model_artifacts" + mkdir -p "$OUTPUT_DIR" + + bash .ci/scripts/export_model_artifact.sh cuda "${{ matrix.model }}" "${{ matrix.quant }}" "$OUTPUT_DIR" + + # Move artifacts to RUNNER_ARTIFACT_DIR for upload + mv "$OUTPUT_DIR"/* "${RUNNER_ARTIFACT_DIR}/" + ls -lah "${RUNNER_ARTIFACT_DIR}" + echo "::endgroup::" + + benchmark-cuda: + name: benchmark-cuda + needs: + - set-parameters + - export-models + if: always() + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + matrix: ${{ fromJson(needs.set-parameters.outputs.benchmark_configs) }} + fail-fast: false + with: + timeout: 90 + runner: linux.g5.4xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + use-custom-docker-registry: false + submodules: recursive + download-artifact: model-${{ matrix.model_safe }}-${{ matrix.quant }} + upload-artifact: results-${{ matrix.model_safe }}-${{ matrix.quant }} + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + set -eux + echo "::group::Setup environment" + ./install_requirements.sh + pip list + echo "::endgroup::" + + echo "::group::Prepare model artifacts" + mkdir -p model_artifacts + cp "${RUNNER_ARTIFACT_DIR}/model.pte" model_artifacts/model.pte + cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" model_artifacts/aoti_cuda_blob.ptd + + # Copy additional files if they exist + if [ -f "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" ]; then + cp "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/whisper_preprocessor.pte" ]; then + cp "${RUNNER_ARTIFACT_DIR}/whisper_preprocessor.pte" model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/tekken.json" ]; then + cp "${RUNNER_ARTIFACT_DIR}/tekken.json" model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/poem.wav" ]; then + cp "${RUNNER_ARTIFACT_DIR}/poem.wav" model_artifacts/ + fi + if [ -f "${RUNNER_ARTIFACT_DIR}/output.wav" ]; then + cp "${RUNNER_ARTIFACT_DIR}/output.wav" model_artifacts/ + fi + # Copy tokenizer files + for file in tokenizer.json tokenizer_config.json special_tokens_map.json; do + if [ -f "${RUNNER_ARTIFACT_DIR}/$file" ]; then + cp "${RUNNER_ARTIFACT_DIR}/$file" model_artifacts/ + fi + done + + ls -lah model_artifacts/ + echo "::endgroup::" + + echo "::group::Build runner" + bash .ci/scripts/test_model_e2e.sh cuda "${{ matrix.model }}" "${{ matrix.quant }}" model_artifacts + echo "::endgroup::" + + echo "::group::Running benchmark for ${{ matrix.model }} (${{ matrix.quant }}) with ${{ matrix.num_runs }} runs" + export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH + + # Get GPU name using nvidia-smi + GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1) + echo "Detected GPU: $GPU_NAME" + + # Get CUDA driver version + CUDA_DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -1) + echo "CUDA Driver Version: $CUDA_DRIVER_VERSION" + + # Create results directory (separate from model artifacts) + RESULTS_DIR="benchmark_results" + mkdir -p "$RESULTS_DIR" + + # Determine model name and runner command based on model + case "${{ matrix.model }}" in + mistralai/Voxtral-Mini-3B-2507) + RUNNER="cmake-out/examples/models/voxtral/voxtral_runner" + PREPROCESSOR="model_artifacts/voxtral_preprocessor.pte" + TOKENIZER="model_artifacts/tekken.json" + AUDIO="model_artifacts/poem.wav" + RUNNER_CMD="$RUNNER --model_path model_artifacts/model.pte --data_path model_artifacts/aoti_cuda_blob.ptd --tokenizer_path $TOKENIZER --audio_path $AUDIO --processor_path $PREPROCESSOR --temperature 0" + MODEL_NAME="voxtral_${{ matrix.quant }}" + ;; + openai/whisper-*) + RUNNER="cmake-out/examples/models/whisper/whisper_runner" + PREPROCESSOR="model_artifacts/whisper_preprocessor.pte" + AUDIO="model_artifacts/output.wav" + RUNNER_CMD="$RUNNER --model_path model_artifacts/model.pte --data_path model_artifacts/aoti_cuda_blob.ptd --tokenizer_path model_artifacts/ --audio_path $AUDIO --processor_path $PREPROCESSOR --temperature 0" + MODEL_NAME=$(echo "${{ matrix.model }}" | sed 's/openai\///')_${{ matrix.quant }} + ;; + google/gemma-3-4b-it) + RUNNER="cmake-out/examples/models/gemma3/gemma3_e2e_runner" + IMAGE="docs/source/_static/img/et-logo.png" + RUNNER_CMD="$RUNNER --model_path model_artifacts/model.pte --data_path model_artifacts/aoti_cuda_blob.ptd --tokenizer_path model_artifacts/ --image_path $IMAGE --temperature 0" + MODEL_NAME="gemma3_${{ matrix.quant }}" + ;; + *) + echo "Error: Unsupported model '${{ matrix.model }}'" + exit 1 + ;; + esac + + # Run benchmark using cuda_benchmark.py + python .ci/scripts/cuda_benchmark.py \ + --runner_command "$RUNNER_CMD" \ + --model_name "$MODEL_NAME" \ + --num_runs "${{ matrix.num_runs }}" \ + --output_json "$RESULTS_DIR/benchmark_results.json" \ + --output_v3 "$RESULTS_DIR/benchmark_results_v3.json" \ + --model "${{ matrix.model }}" \ + --quantization "${{ matrix.quant }}" \ + --git_sha "${{ github.sha }}" \ + --workflow_run_id "${{ github.run_id }}" \ + --workflow_run_url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" \ + --gpu_name "$GPU_NAME" \ + --cuda_driver_version "$CUDA_DRIVER_VERSION" + + # Save additional metadata + cat > "$RESULTS_DIR/metadata.json" <&1) - EXIT_CODE=$? - set -e - - echo "$OUTPUT" - - if ! echo "$OUTPUT" | grep -iq "Samantha"; then - echo "Expected output 'Samantha' not found in output" - exit 1 - fi - - if [ $EXIT_CODE -ne 0 ]; then - echo "Unexpected exit code: $EXIT_CODE" - exit $EXIT_CODE - fi - echo "::endgroup::" + ${CONDA_RUN} bash .ci/scripts/test_model_e2e.sh metal "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 72c3dc6222d..c9dd6a0b734 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -13,32 +13,32 @@ concurrency: cancel-in-progress: true jobs: - test-qnn-wheel-packages-linux: - name: test-qnn-wheel-packages-linux - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - strategy: - fail-fast: false - matrix: - python-version: [ "3.10", "3.11", "3.12" ] - with: - runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-22.04-qnn-sdk - submodules: 'recursive' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - timeout: 180 - script: | - # The generic Linux job chooses to use base env, not the one setup by the image - CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") - conda activate "${CONDA_ENV}" + # test-qnn-wheel-packages-linux: + # name: test-qnn-wheel-packages-linux + # uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + # permissions: + # id-token: write + # contents: read + # strategy: + # fail-fast: false + # matrix: + # python-version: [ "3.10", "3.11", "3.12" ] + # with: + # runner: linux.2xlarge + # docker-image: ci-image:executorch-ubuntu-22.04-qnn-sdk + # submodules: 'recursive' + # ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + # timeout: 180 + # script: | + # # The generic Linux job chooses to use base env, not the one setup by the image + # CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + # conda activate "${CONDA_ENV}" - # Create a clean env for each python version - conda create -y -n test_env_${{ matrix.python-version }} python=${{ matrix.python-version }} - conda activate test_env_${{ matrix.python-version }} + # # Create a clean env for each python version + # conda create -y -n test_env_${{ matrix.python-version }} python=${{ matrix.python-version }} + # conda activate test_env_${{ matrix.python-version }} - PYTHON_EXECUTABLE=python bash .ci/scripts/test_wheel_package_qnn.sh "${{ matrix.python-version }}" + # PYTHON_EXECUTABLE=python bash .ci/scripts/test_wheel_package_qnn.sh "${{ matrix.python-version }}" test-setup-linux-gcc: name: test-setup-linux-gcc @@ -340,6 +340,7 @@ jobs: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 90 script: | + set -eux # The generic Linux job chooses to use base env, not the one setup by the image CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" diff --git a/.github/workflows/test-backend-arm.yml b/.github/workflows/test-backend-arm.yml index 22e3d524f6b..638d5a2079f 100644 --- a/.github/workflows/test-backend-arm.yml +++ b/.github/workflows/test-backend-arm.yml @@ -26,7 +26,7 @@ jobs: uses: ./.github/workflows/_test_backend.yml with: backend: arm - flows: '["arm_tosa_fp", "arm_tosa_int", "arm_ethos_u55", "arm_ethos_u85"]' + flows: '["arm_tosa_fp", "arm_tosa_int", "arm_ethos_u55", "arm_ethos_u85", "arm_vgf_fp", "arm_vgf_int"]' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: 120 run-linux: true diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 03a13e3717b..cc918034988 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -317,6 +317,40 @@ jobs: # Test test_arm_baremetal.sh with test backends/arm/test/test_arm_baremetal.sh "${ARM_TEST}" + test-arm-backend-vkml: + name: test-arm-backend-vkml + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + matrix: + include: + - test_arm_baremetal: test_pytest_ops_vkml + fail-fast: false + with: + runner: linux.2xlarge.memory + docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + source .ci/scripts/utils.sh + install_executorch "--use-pt-pinned-commit" + + .ci/scripts/setup-arm-baremetal-tools.sh --disable-ethos-u-deps --enable-mlsdk-deps --install-mlsdk-deps-with-pip + + # Increase number of files user can monitor to bypass buck failures. + # Hopefully this is high enough for this setup. + sudo sysctl fs.inotify.max_user_watches=1048576 # 1024 * 1024 + + ARM_TEST=${{ matrix.test_arm_baremetal }} + + backends/arm/test/test_arm_baremetal.sh "${ARM_TEST}" + test-arm-cortex-m-size-test: name: test-arm-cortex-m-size-test uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main @@ -347,7 +381,7 @@ jobs: elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then setup_script_args="--target-toolchain zephyr" toolchain_prefix=arm-zephyr-eabi- - threshold="135656" # 132 KiB + threshold="135768" # 136 KiB toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake else echo "Fail unsupport OS selection ${{ matrix.os }}" @@ -1066,3 +1100,33 @@ jobs: .ci/scripts/test_model.ps1 -modelName ${{ matrix.model }} -backend ${{ matrix.backend }} }" + + test-mcu-cortex-m-backend: + name: test-mcu-cortex-m-backend + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: linux.2xlarge.memory + docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + source .ci/scripts/utils.sh + install_executorch "--use-pt-pinned-commit" + + # Install arm dependencies + .ci/scripts/setup-arm-baremetal-tools.sh + source examples/arm/ethos-u-scratch/setup_path.sh + + # To build cortex-m test runner + backends/cortex_m/test/build_test_runner.sh + + # To run cortex_m tests + pytest --config-file=backends/arm/test/pytest.ini backends/cortex_m/test diff --git a/.lintrunner.toml b/.lintrunner.toml index b366c141799..396b7fde5ac 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -367,7 +367,7 @@ exclude_patterns = [ '**/third-party/**', 'scripts/check_binary_dependencies.py', 'profiler/test/test_profiler_e2e.py', - 'backends/arm/test/**', + 'backends/arm/test/ops/*.py', ] command = [ 'python', @@ -449,3 +449,24 @@ command = [ "--", "@{{PATHSFILE}}", ] + +[[linter]] +code = 'ETVKNODEBUG' +include_patterns = [ + "backends/vulkan/**/*.glsl", +] +command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'grep_linter', + '--pattern=((DEBUG_MODE)|(GL_EXT_debug_printf))', + '--linter-name=ETVKNODEBUG', + '--error-name=Using DEBUG_MODE or GL_EXT_debug_printf in Vulkan shader', + """--error-description=\ + #define DEBUG_MODE or #extension GL_EXT_debug_printf should only be used during development! + """, + '--', + '@{{PATHSFILE}}', +] diff --git a/.mypy.ini b/.mypy.ini index baea2efefa9..0ce444e8a79 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -24,11 +24,14 @@ files = test, util -mypy_path = executorch +mypy_path = executorch,src [mypy-executorch.backends.*] follow_untyped_imports = True +[mypy-backends.arm.*] +disallow_untyped_decorators = False + [mypy-executorch.codegen.*] follow_untyped_imports = True @@ -74,6 +77,12 @@ ignore_missing_imports = True [mypy-pytorch_sphinx_theme] ignore_missing_imports = True +[mypy-pytorch_sphinx_theme2] +ignore_missing_imports = True + +[mypy-executorch.version] +ignore_missing_imports = True + [mypy-ruamel] ignore_missing_imports = True diff --git a/CMakeLists.txt b/CMakeLists.txt index c6d6f26b41f..e7e1cb96b6b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -119,6 +119,10 @@ if(EXECUTORCH_ENABLE_EVENT_TRACER) add_definitions(-DET_EVENT_TRACER_ENABLED) endif() +if(EXECUTORCH_ENABLE_BUNDLE_IO) + add_definitions(-DET_BUNDLE_IO_ENABLED) +endif() + # -ffunction-sections -fdata-sections: breaks function and data into sections so # they can be properly gc'd. -s: strip symbol. if(WIN32) @@ -591,8 +595,9 @@ endif() if(EXECUTORCH_BUILD_CUDA) # Build CUDA-specific AOTI functionality add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cuda) - # Add aoti_cuda to backends - it already depends on aoti_common - list(APPEND _executorch_backends aoti_cuda) + # Add aoti_cuda_backend to backends - it transitively includes aoti_cuda_shims + # and cuda_platform + list(APPEND _executorch_backends aoti_cuda_backend) endif() if(EXECUTORCH_BUILD_METAL) @@ -800,6 +805,9 @@ if(EXECUTORCH_BUILD_PYBIND) torch ) + # RPATH for _portable_lib.so + set(_portable_lib_rpath "$ORIGIN/../../../torch/lib") + if(EXECUTORCH_BUILD_EXTENSION_MODULE) # Always use static linking for pybindings to avoid runtime symbol # resolution issues @@ -834,6 +842,7 @@ if(EXECUTORCH_BUILD_PYBIND) if(EXECUTORCH_BUILD_QNN) list(APPEND _dep_libs qnn_executorch_backend) + string(APPEND _portable_lib_rpath ":$ORIGIN/../../backends/qualcomm") endif() if(EXECUTORCH_BUILD_ENN) @@ -885,10 +894,11 @@ if(EXECUTORCH_BUILD_PYBIND) target_compile_options(portable_lib PUBLIC ${_pybind_compile_options}) target_link_libraries(portable_lib PRIVATE ${_dep_libs}) - # Set RPATH to find PyTorch libraries relative to the installation location - # This goes from executorch/extension/pybindings up to site-packages, then to - # torch/lib. Don't do this to APPLE, as it will error out on the following - # error: + # Set RPATH to find PyTorch and backend libraries relative to the installation + # location. This goes from executorch/extension/pybindings up to + # site-packages, then to torch/lib. If QNN is enabled, also add + # backends/qualcomm/. Don't do this to APPLE, as it will error out on the + # following error: # if(APPLE) # Skip setting @loader_path for APPLE, since it causes error like ld: @@ -896,8 +906,8 @@ if(EXECUTORCH_BUILD_PYBIND) # libtorch_cpu.dylib' else() set_target_properties( - portable_lib PROPERTIES BUILD_RPATH "$ORIGIN/../../../torch/lib" - INSTALL_RPATH "$ORIGIN/../../../torch/lib" + portable_lib PROPERTIES BUILD_RPATH "${_portable_lib_rpath}" + INSTALL_RPATH "${_portable_lib_rpath}" ) endif() @@ -1077,6 +1087,10 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER) list(APPEND _executor_runner_libs etdump flatccrt) endif() + if(EXECUTORCH_ENABLE_BUNDLE_IO) + list(APPEND _executor_runner_libs bundled_program) + endif() + add_executable(executor_runner ${_executor_runner__srcs}) if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(executor_runner) diff --git a/CMakePresets.json b/CMakePresets.json index 379f4f418ed..12e398b4fe4 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -119,38 +119,118 @@ } }, { - "name": "llm", - "displayName": "Build LLM libraries", - "inherits": ["common"], - "cacheVariables": { - "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/llm.cmake", - "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" - }, - "condition": { - "type": "inList", - "string": "${hostSystemName}", - "list": ["Darwin", "Linux", "Windows"] - } + "name": "llm", + "displayName": "Build LLM libraries", + "inherits": [ + "common" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/llm.cmake", + "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Darwin", "Linux", "Windows"] + } }, { - "name": "profiling", - "displayName": "Build ExecuTorch with Profiling Enabled", - "inherits": [ - "common" - ], - "cacheVariables": { - "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/profiling.cmake", - "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" - }, - "condition": { - "type": "inList", - "string": "${hostSystemName}", - "list": [ - "Darwin", - "Linux", - "Windows" - ] - } + "name": "llm-release", + "displayName": "LLM release build", + "inherits": [ + "llm" + ], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out" + } + }, + { + "name": "llm-release-cuda", + "displayName": "LLM release build with CUDA", + "inherits": [ + "llm-release" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Linux" + } + }, + { + "name": "llm-release-metal", + "displayName": "LLM release build with Metal", + "inherits": [ + "llm-release" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, + { + "name": "llm-debug", + "displayName": "LLM debug build", + "inherits": [ + "llm" + ], + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_INSTALL_PREFIX": "${sourceDir}/cmake-out" + } + }, + { + "name": "llm-debug-cuda", + "displayName": "LLM debug build with CUDA", + "inherits": [ + "llm-debug" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Linux" + } + }, + { + "name": "llm-debug-metal", + "displayName": "LLM debug build with Metal", + "inherits": [ + "llm-debug" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + }, + { + "name": "profiling", + "displayName": "Build ExecuTorch with Profiling Enabled", + "inherits": [ + "common" + ], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/profiling.cmake", + "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Darwin", "Linux", "Windows"] + } }, { "name": "windows", @@ -177,13 +257,155 @@ } }, { - "name": "arm-baremetal", - "displayName": "Build ExecuTorch for Arm baremetal", - "inherits": ["common"], - "cacheVariables": { - "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/arm_baremetal.cmake", - "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake" - } + "name": "arm-baremetal", + "displayName": "Build ExecuTorch for Arm baremetal", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/arm_baremetal.cmake", + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake" + } + } + ], + "buildPresets": [ + { + "name": "llm-release-install", + "displayName": "Build and install LLM extension release artifacts", + "configurePreset": "llm-release", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-release-cuda-install", + "displayName": "Build and install LLM extension release artifacts (CUDA)", + "configurePreset": "llm-release-cuda", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-release-metal-install", + "displayName": "Build and install LLM extension release artifacts (Metal)", + "configurePreset": "llm-release-metal", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-debug-install", + "displayName": "Build and install LLM extension debug artifacts", + "configurePreset": "llm-debug", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-debug-cuda-install", + "displayName": "Build and install LLM extension debug artifacts (CUDA)", + "configurePreset": "llm-debug-cuda", + "targets": [ + "install" + ], + "jobs": 0 + }, + { + "name": "llm-debug-metal-install", + "displayName": "Build and install LLM extension debug artifacts (Metal)", + "configurePreset": "llm-debug-metal", + "targets": [ + "install" + ], + "jobs": 0 + } + ], + "workflowPresets": [ + { + "name": "llm-release", + "displayName": "Configure, build and install ExecuTorch LLM extension with default CPU backend", + "steps": [ + { + "type": "configure", + "name": "llm-release" + }, + { + "type": "build", + "name": "llm-release-install" + } + ] + }, + { + "name": "llm-release-cuda", + "displayName": "Configure, build and install ExecuTorch LLM extension with CUDA enabled", + "steps": [ + { + "type": "configure", + "name": "llm-release-cuda" + }, + { + "type": "build", + "name": "llm-release-cuda-install" + } + ] + }, + { + "name": "llm-release-metal", + "displayName": "Configure, build and install ExecuTorch LLM extension with Metal enabled", + "steps": [ + { + "type": "configure", + "name": "llm-release-metal" + }, + { + "type": "build", + "name": "llm-release-metal-install" + } + ] + }, + { + "name": "llm-debug", + "displayName": "Configure, build and install ExecuTorch LLM extension with default CPU backend (Debug)", + "steps": [ + { + "type": "configure", + "name": "llm-debug" + }, + { + "type": "build", + "name": "llm-debug-install" + } + ] + }, + { + "name": "llm-debug-cuda", + "displayName": "Configure, build and install ExecuTorch LLM extension with CUDA enabled (Debug)", + "steps": [ + { + "type": "configure", + "name": "llm-debug-cuda" + }, + { + "type": "build", + "name": "llm-debug-cuda-install" + } + ] + }, + { + "name": "llm-debug-metal", + "displayName": "Configure, build and install ExecuTorch LLM extension with Metal enabled (Debug)", + "steps": [ + { + "type": "configure", + "name": "llm-debug-metal" + }, + { + "type": "build", + "name": "llm-debug-metal-install" + } + ] } ] } diff --git a/Makefile b/Makefile new file mode 100644 index 00000000000..bc42e8beaf9 --- /dev/null +++ b/Makefile @@ -0,0 +1,197 @@ +# ============================================================================== +# ExecuTorch Targets Makefile +# ============================================================================== +# +# This Makefile provides convenient targets for building ExecuTorch model runners +# with different backend configurations (CPU, CUDA, Metal), as well as other +# binary targets. +# +# WHAT THIS BUILDS: +# ----------------- +# Each target builds: +# 1. ExecuTorch core libraries with the specified backend (CPU, CUDA, or Metal) +# 2. The model-specific runner executable in cmake-out/examples/models// +# +# SUPPORTED MODELS: +# ----------------- +# - voxtral: Multimodal voice + text model (CPU, CUDA, Metal) +# - whisper: Speech recognition model (CPU, CUDA, Metal) +# - llama: Text generation model (CPU) +# - llava: Vision + language model (CPU) +# - gemma3: Text generation model (CPU, CUDA) +# +# USAGE: +# ------ +# make - # Build a specific model with a backend +# make help # Show all available targets +# make clean # Remove all build artifacts +# +# Examples: +# make voxtral-cuda # Build Voxtral with CUDA backend +# make llama-cpu # Build Llama with CPU backend +# make whisper-metal # Build Whisper with Metal backend (macOS) +# +# HOW TO ADD A NEW MODEL: +# ----------------------- +# To add a new model (e.g., "mymodel"), follow these steps: +# +# 1. Create a CMakePresets.json in examples/models/mymodel/: +# - Define configurePresets for each backend (base, cpu, cuda, metal) +# - Define buildPresets with the target name from CMakeLists.txt +# - Define workflowPresets that combine configure + build steps +# - See examples/models/voxtral/CMakePresets.json for multi-backend reference +# - Or see examples/models/llama/CMakePresets.json for simple single-preset reference +# +# 2. Add targets to this Makefile: +# a) Add to .PHONY declaration: mymodel-cuda mymodel-cpu mymodel-metal +# b) Add help text in the help target +# c) Add target implementations following this pattern: +# +# mymodel-cuda: +# @echo "==> Building and installing ExecuTorch with CUDA..." +# cmake --workflow --preset llm-release-cuda +# @echo "==> Building MyModel runner with CUDA..." +# cd examples/models/mymodel && cmake --workflow --preset mymodel-cuda +# @echo "" +# @echo "✓ Build complete!" +# @echo " Binary: cmake-out/examples/models/mymodel/mymodel_runner" +# +# mymodel-cpu: +# @echo "==> Building and installing ExecuTorch..." +# cmake --workflow --preset llm-release +# @echo "==> Building MyModel runner (CPU)..." +# cd examples/models/mymodel && cmake --workflow --preset mymodel-cpu +# @echo "" +# @echo "✓ Build complete!" +# @echo " Binary: cmake-out/examples/models/mymodel/mymodel_runner" +# +# mymodel-metal: +# @echo "==> Building and installing ExecuTorch with Metal..." +# cmake --workflow --preset llm-release-metal +# @echo "==> Building MyModel runner with Metal..." +# cd examples/models/mymodel && cmake --workflow --preset mymodel-metal +# @echo "" +# @echo "✓ Build complete!" +# @echo " Binary: cmake-out/examples/models/mymodel/mymodel_runner" +# +# 3. Test your new targets: +# make mymodel-cpu # or mymodel-cuda, mymodel-metal +# +# NOTES: +# ------ +# - CUDA backend is only available on Linux systems +# - Metal backend is only available on macOS (Darwin) systems +# - Some models may not support all backends (check model documentation) +# - Binary outputs are located in cmake-out/examples/models// +# - The preset names in CMakePresets.json must match the names used in Makefile +# +# ============================================================================== + +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cpu whisper-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help + +help: + @echo "This Makefile adds targets to build runners for various models on various backends. Run using `make `. Available targets:" + @echo " voxtral-cuda - Build Voxtral runner with CUDA backend" + @echo " voxtral-cpu - Build Voxtral runner with CPU backend" + @echo " voxtral-metal - Build Voxtral runner with Metal backend (macOS only)" + @echo " whisper-cuda - Build Whisper runner with CUDA backend" + @echo " whisper-cpu - Build Whisper runner with CPU backend" + @echo " whisper-metal - Build Whisper runner with Metal backend (macOS only)" + @echo " llama-cpu - Build Llama runner with CPU backend" + @echo " llava-cpu - Build Llava runner with CPU backend" + @echo " gemma3-cuda - Build Gemma3 runner with CUDA backend" + @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" + @echo " clean - Clean build artifacts" + +voxtral-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Voxtral runner with CUDA..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + +voxtral-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Voxtral runner (CPU)..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-cpu + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + +voxtral-metal: + @echo "==> Building and installing ExecuTorch with Metal..." + cmake --workflow --preset llm-release-metal + @echo "==> Building Voxtral runner with Metal..." + cd examples/models/voxtral && cmake --workflow --preset voxtral-metal + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral/voxtral_runner" + +whisper-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Whisper runner with CUDA..." + cd examples/models/whisper && cmake --workflow --preset whisper-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/whisper/whisper_runner" + +whisper-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Whisper runner (CPU)..." + cd examples/models/whisper && cmake --workflow --preset whisper-cpu + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/whisper/whisper_runner" + +whisper-metal: + @echo "==> Building and installing ExecuTorch with Metal..." + cmake --workflow --preset llm-release-metal + @echo "==> Building Whisper runner with Metal..." + cd examples/models/whisper && cmake --workflow --preset whisper-metal + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/whisper/whisper_runner" + +llama-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Llama runner (CPU)..." + cd examples/models/llama && cmake --workflow --preset llama-release + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/llama/llama_main" + +llava-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Llava runner (CPU)..." + cd examples/models/llava && cmake --workflow --preset llava + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/llava/llava_main" + +gemma3-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Gemma3 runner with CUDA..." + cd examples/models/gemma3 && cmake --workflow --preset gemma3-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma3/gemma3_e2e_runner" + +gemma3-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Gemma3 runner (CPU)..." + cd examples/models/gemma3 && cmake --workflow --preset gemma3-cpu + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/gemma3/gemma3_e2e_runner" + +clean: + rm -rf cmake-out diff --git a/backends/aoti/CMakeLists.txt b/backends/aoti/CMakeLists.txt index bcff1d56769..d5582dfe7c7 100644 --- a/backends/aoti/CMakeLists.txt +++ b/backends/aoti/CMakeLists.txt @@ -38,6 +38,9 @@ target_compile_options( PUBLIC $<$:/EHsc /GR> $<$>:-fexceptions -frtti -fPIC> ) +target_compile_definitions( + aoti_common PRIVATE $<$:EXPORT_AOTI_FUNCTIONS> +) # Ensure symbols are exported properly if(APPLE) target_link_options(aoti_common PUBLIC -Wl,-export_dynamic) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py new file mode 100644 index 00000000000..2d396a296bd --- /dev/null +++ b/backends/aoti/aoti_backend.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import typing +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional, Set + +import torch +from executorch.backends.aoti.passes.replace_view_copy_with_view import ( + ReplaceViewCopyWithViewPass, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +class COMPILE_SPEC_KEYS(Enum): + METHOD_NAME = "method_name" + + +@experimental( + "This API and all of aoti-driven backend related functionality are experimental." +) +class AotiBackend(ABC): + """ + Base mixin class for AOTInductor-based backends. + + This class provides common functionality for compiling models using AOTInductor + with different device targets (CUDA, Metal, etc.). + + This is a mixin class, not an actual backend object, for aoti-driven backends. + Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both + BackendDetails and AotiBackend to get the full functionality. + """ + + @classmethod + @abstractmethod + def get_device_name(cls) -> str: + """Return the device name for this backend (e.g., 'cuda', 'metal').""" + pass + + @classmethod + @abstractmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + """Return the set of supported fallback kernels for this backend.""" + pass + + @classmethod + @abstractmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + """Return the decomposition table for this backend.""" + pass + + @classmethod + @abstractmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """Return the AOTInductor compilation options for this backend.""" + pass + + @classmethod + @abstractmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition.""" + pass + + @classmethod + @contextlib.contextmanager + def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]): + """ + Context manager to collect unsupported fallback kernels during compilation. + Monitors both extern kernel calls and runtime lookup. + """ + supported_kernels = cls.get_supported_fallback_kernels() + + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + original_generate_fallback_kernel_with_runtime_lookup_aot = ( + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + debug_handle: Optional[int] = None, + ): + if kernel not in supported_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, + kernel, + args, + device, + debug_args=debug_args, + debug_handle=debug_handle, + ) + + def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( + self, + op_overload, + raw_args, + output_args, + raw_outputs, + ): + kernel_name = getattr(op_overload, "_name", str(op_overload)) + if kernel_name not in supported_kernels: + missing_fallback_kernels.add(kernel_name) + + original_generate_fallback_kernel_with_runtime_lookup_aot( + self, op_overload, raw_args, output_args, raw_outputs + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels + + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( + original_generate_fallback_kernel_with_runtime_lookup_aot + ) + + @classmethod + def preprocess( + cls, + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Preprocess the edge program and compile it using AOTInductor. + Weights are always separated from the SO file. + """ + device_name = cls.get_device_name() + decomposition_table = cls.get_decomposition_table() + options = cls.get_aoti_compile_options(compile_specs) + + # Move the edge_program to the target device + device_edge_program = move_to_device_pass( + edge_program, device_name if device_name != "metal" else "mps" + ) + + # Replace view_copy with view + ReplaceViewCopyWithViewPass()(device_edge_program.graph_module) + + # Apply custom backend-specific passes + custom_passes = cls.get_custom_passes() + for custom_pass in custom_passes: + custom_pass(device_edge_program.graph_module) + + # Run decompositions if any + if decomposition_table: + device_edge_program = device_edge_program.run_decompositions( + decomposition_table + ) + + edge_program_module = device_edge_program.module() + + # Grab all input placeholders from the graph + user_input_names = device_edge_program.graph_signature.user_inputs + user_input_placeholders = [] + for node in device_edge_program.graph.nodes: + if node.op == "placeholder" and node.name in user_input_names: + user_input_placeholders.append(node.meta["val"]) + + # Track missing fallback kernels + missing_fallback_kernels: Set[str] = set() + + # Compile with fallback kernel collection + with cls.collect_unsupported_fallback_kernels( + missing_fallback_kernels + ), torch.no_grad(): + paths = torch._inductor.aot_compile( + edge_program_module, tuple(user_input_placeholders), options=options + ) + + if len(missing_fallback_kernels) > 0: + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) + method_name = cls.method_name_from_compile_specs(compile_specs) + raise RuntimeError( + f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" + "Please add them to the AOTI backend." + ) + + # Extract paths - weights are always separated + so_path = None + blob_path = None + + if isinstance(paths, list): + for path in paths: + if path.endswith(".wrapper.so"): + so_path = path + elif path.endswith(".wrapper_weights.blob"): + blob_path = path + else: + so_path = paths + + if so_path is None or blob_path is None: + raise RuntimeError( + f"Could not find required files in compiled paths, got {paths}" + ) + + # Read SO file + with open(so_path, "rb") as f: + so_data = f.read() + + # Read weights blob + with open(blob_path, "rb") as f: + blob_data = f.read() + + # Create named data store + named_data_store = NamedDataStore() + method_name = cls.method_name_from_compile_specs(compile_specs) + + # Add SO and weights blob separately + named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) + weights_blob_data_type = f"aoti_{device_name}_blob" + named_data_store.add_named_data( + method_name + "_weights_blob", blob_data, 1, weights_blob_data_type + ) + + # Clean up the generated files + os.remove(so_path) + os.remove(blob_path) + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + @classmethod + def generate_method_name_compile_spec( + cls, + method_name: str, + ) -> CompileSpec: + """ + Generate a CompileSpec for the given method name. + """ + return CompileSpec( + COMPILE_SPEC_KEYS.METHOD_NAME.value, + method_name.encode("utf-8"), + ) + + @classmethod + def method_name_from_compile_specs( + cls, + compile_specs: List[CompileSpec], + ) -> str: + """ + Extract the method name from the compile specs. + """ + for spec in compile_specs: + if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: + return spec.value.decode("utf-8") + raise RuntimeError( + f"Could not find method name in compile specs: {compile_specs}" + ) diff --git a/backends/aoti/aoti_partitioner.py b/backends/aoti/aoti_partitioner.py index 499bc57b735..aa56d3507e9 100644 --- a/backends/aoti/aoti_partitioner.py +++ b/backends/aoti/aoti_partitioner.py @@ -52,10 +52,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: partition_tags: Dict[str, DelegationSpec] = {} tag = "tag0" + # Tag torch.cond and other control flow operations + def is_control_flow(node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + torch.ops.higher_order.cond, + torch.ops.higher_order.map_impl, + torch.ops.higher_order.while_loop, + ] + for node in exported_program.graph.nodes: - if node.op != "call_function": - continue - node.meta["delegation_tag"] = tag + if node.op == "call_function": + node.meta["delegation_tag"] = tag + # Tag get_attr nodes that are used by control flow operations + elif node.op == "get_attr": + # Check if any user is a control flow operation + for user in node.users: + if is_control_flow(user): + node.meta["delegation_tag"] = tag + break partition_tags[tag] = self.delegation_spec diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index deb10478778..abfde86db6d 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -16,8 +16,10 @@ namespace aoti { namespace internal { // Global storage for tensor metadata -std::unordered_map> tensor_to_sizes; -std::unordered_map> tensor_to_strides; +AOTI_SHIM_EXPORT std::unordered_map> + tensor_to_sizes; +AOTI_SHIM_EXPORT std::unordered_map> + tensor_to_strides; } // namespace internal extern "C" { @@ -204,6 +206,61 @@ void cleanup_tensor_metadata() { internal::tensor_to_strides.clear(); } +AOTI_SHIM_EXPORT void aoti_torch_warn( + const char* func, + const char* file, + uint32_t line, + const char* msg) { + ET_LOG(Error, "[%s:%u] %s: %s", file, line, func, msg); +} + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) { + (void)tensor; + (void)ret_size; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor) { + (void)self; + (void)ret_new_tensor; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) { + (void)self; + (void)ret_new_tensor; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + void* data_ptr, + int64_t ndim, + const int64_t* sizes, + const int64_t* strides, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor) { + (void)data_ptr; + (void)ndim; + (void)sizes; + (void)strides; + (void)storage_offset; + (void)dtype; + (void)device_type; + (void)device_index; + (void)ret_new_tensor; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + } // extern "C" } // namespace aoti diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 91bb785b684..675a9864e74 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -23,57 +24,86 @@ namespace aoti { using executorch::runtime::Error; using executorch::runtime::etensor::Tensor; +// Global storage for tensor metadata +extern std::unordered_map> tensor_to_sizes; +extern std::unordered_map> tensor_to_strides; + extern "C" { // Common AOTI type aliases using AOTIRuntimeError = Error; using AOTITorchError = Error; -// Global storage for tensor metadata -extern std::unordered_map> tensor_to_sizes; -extern std::unordered_map> tensor_to_strides; - // Attribute-related operations (memory-irrelevant) -AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr); -AOTITorchError aoti_torch_get_storage_offset( - Tensor* tensor, - int64_t* ret_storage_offset); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset); -AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides); -AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype); -AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes); -AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size); -AOTITorchError aoti_torch_get_device_index( - Tensor* tensor, - int32_t* ret_device_index); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index); -AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim); // Utility functions for device and layout information -int32_t aoti_torch_device_type_cpu(); -int32_t aoti_torch_layout_strided(); -int32_t aoti_torch_dtype_float32(); -int32_t aoti_torch_dtype_bfloat16(); -int32_t aoti_torch_dtype_int8(); -int32_t aoti_torch_dtype_int16(); -int32_t aoti_torch_dtype_int32(); -int32_t aoti_torch_dtype_int64(); -int32_t aoti_torch_dtype_bool(); +AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu(); +AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64(); // Dtype utility function needed by Metal backend -size_t aoti_torch_dtype_element_size(int32_t dtype); +AOTI_SHIM_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype); // Autograd mode functions -int32_t aoti_torch_grad_mode_is_enabled(); -void aoti_torch_grad_mode_set_enabled(bool enabled); +AOTI_SHIM_EXPORT int32_t aoti_torch_grad_mode_is_enabled(); +AOTI_SHIM_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled); // Cleanup functions for clearing global state -void cleanup_tensor_metadata(); +AOTI_SHIM_EXPORT void cleanup_tensor_metadata(); + +AOTI_SHIM_EXPORT void aoti_torch_warn( + const char* func, + const char* file, + uint32_t line, + const char* msg); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor); + +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor); + +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + void* data_ptr, + int64_t ndim, + const int64_t* sizes, + const int64_t* strides, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + Tensor** ret_new_tensor); } // extern "C" diff --git a/backends/aoti/export.h b/backends/aoti/export.h new file mode 100644 index 00000000000..7c945f405b0 --- /dev/null +++ b/backends/aoti/export.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// Define export macro for Windows DLL +// When building the aoti_cuda_backend library, EXPORT_AOTI_FUNCTIONS is defined +// by CMake, which causes this macro to export symbols using +// __declspec(dllexport). When consuming the library, the macro imports symbols +// using +// __declspec(dllimport). On non-Windows platforms, the macro is empty and has +// no effect. +#ifdef _WIN32 +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_SHIM_EXPORT __declspec(dllexport) +#else +#define AOTI_SHIM_EXPORT __declspec(dllimport) +#endif +#else +#define AOTI_SHIM_EXPORT +#endif diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl index 560cf52e06f..327bef8cc53 100644 --- a/backends/aoti/targets.bzl +++ b/backends/aoti/targets.bzl @@ -16,6 +16,23 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "aoti_backend", + srcs = [ + "aoti_backend.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/aoti/passes:passes", + "//executorch/exir/_serialize:lib", + "//executorch/exir/backend:backend_details", + "//executorch/exir/backend:compile_spec_schema", + ], + ) + # AOTI common shims functionality runtime.cxx_library( name = "common_shims", @@ -24,6 +41,7 @@ def define_common_targets(): ], headers = [ "common_shims.h", + "export.h", "utils.h", ], # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index 53ac436fe38..29c7120feb7 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -20,8 +20,6 @@ NUM_TO_TORCH_DTYPE, split, to, - transpose, - unbind, ) from coremltools.converters.mil.frontend.torch.torch_op_registry import ( register_torch_op, @@ -30,18 +28,6 @@ from executorch.exir.dim_order_utils import get_memory_format -# https://github.com/apple/coremltools/pull/2556 -@register_torch_op(override=False) -def transpose_copy(context, node): - transpose(context, node) - - -# https://github.com/apple/coremltools/pull/2557 -@register_torch_op(override=False) -def unbind_copy(context, node): - unbind(context, node) - - # https://github.com/apple/coremltools/pull/2563 @register_torch_op(override=False) def split_copy(context, node): @@ -117,7 +103,9 @@ def _clone_dim_order(context, node): # https://github.com/apple/coremltools/pull/2558 @register_torch_op( torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"], - override=False, + # coremltools did not merge the fix into 9.0 (https://github.com/apple/coremltools/pull/2589), + # so we override here + override=True, ) def dequantize_affine(context, node): inputs = _get_inputs(context, node, expected=[7, 8]) diff --git a/backends/apple/coreml/scripts/build_tests.sh b/backends/apple/coreml/scripts/build_tests.sh index 190adf1f65a..0203e5027a2 100755 --- a/backends/apple/coreml/scripts/build_tests.sh +++ b/backends/apple/coreml/scripts/build_tests.sh @@ -30,7 +30,8 @@ rm -rf "$CMAKE_EXECUTORCH_BUILD_DIR_PATH" cmake "$EXECUTORCH_ROOT_PATH" -B"$CMAKE_EXECUTORCH_BUILD_DIR_PATH" \ -DCMAKE_TOOLCHAIN_FILE="$IOS_TOOLCHAIN_PATH" \ --DPLATFORM=MAC_UNIVERSAL \ +-DPLATFORM=MAC_ARM64 \ +-DCMAKE_OSX_ARCHITECTURES=arm64 \ -DDEPLOYMENT_TARGET=13.0 \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \ -DEXECUTORCH_BUILD_XNNPACK=OFF @@ -44,7 +45,8 @@ rm -rf "$CMAKE_PROTOBUF_BUILD_DIR_PATH" cmake "$PROTOBUF_DIR_PATH/cmake" -B"$CMAKE_PROTOBUF_BUILD_DIR_PATH" \ -DCMAKE_TOOLCHAIN_FILE="$IOS_TOOLCHAIN_PATH" \ --DPLATFORM=MAC_UNIVERSAL \ +-DPLATFORM=MAC_ARM64 \ +-DCMAKE_OSX_ARCHITECTURES=arm64 \ -DDEPLOYMENT_TARGET=13.0 \ -Dprotobuf_BUILD_TESTS=OFF \ -Dprotobuf_BUILD_EXAMPLES=OFF \ @@ -55,7 +57,8 @@ cmake --build "$CMAKE_PROTOBUF_BUILD_DIR_PATH" -j9 -t libprotobuf-lite # Copy required libraries echo "ExecuTorch: Copying libraries" -mkdir "$LIBRARIES_DIR_PATH" +rm -rf $LIBRARIES_DIR_PATH +mkdir -p "$LIBRARIES_DIR_PATH" cp -f "$CMAKE_EXECUTORCH_BUILD_DIR_PATH/libexecutorch.a" "$LIBRARIES_DIR_PATH" cp -f "$CMAKE_EXECUTORCH_BUILD_DIR_PATH/libexecutorch_core.a" "$LIBRARIES_DIR_PATH" cp -f "$CMAKE_PROTOBUF_BUILD_DIR_PATH/libprotobuf-lite.a" "$LIBRARIES_DIR_PATH" diff --git a/backends/apple/coreml/scripts/generate_test_models.sh b/backends/apple/coreml/scripts/generate_test_models.sh index 6a73d697379..bb5de781b5e 100755 --- a/backends/apple/coreml/scripts/generate_test_models.sh +++ b/backends/apple/coreml/scripts/generate_test_models.sh @@ -15,7 +15,9 @@ COREML_DIR_PATH="$EXECUTORCH_ROOT_PATH/backends/apple/coreml" cd "$EXECUTORCH_ROOT_PATH" -mkdir "$COREML_DIR_PATH/runtime/test/models/" +rm -rf "$COREML_DIR_PATH/runtime/test/models/" +mkdir -p "$COREML_DIR_PATH/runtime/test/models/" + #Generate models cd "$EXECUTORCH_ROOT_PATH" diff --git a/backends/apple/coreml/scripts/install_requirements.sh b/backends/apple/coreml/scripts/install_requirements.sh index 5ec1ea6a1de..f57df535d86 100755 --- a/backends/apple/coreml/scripts/install_requirements.sh +++ b/backends/apple/coreml/scripts/install_requirements.sh @@ -5,6 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +set -euo pipefail + SCRIPT_DIR_PATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 pwd -P @@ -12,10 +14,16 @@ SCRIPT_DIR_PATH="$( # TODO(jathu): remove the need to fetch coremltools to build deps for coreml_executor_runner. # Keep this version in sync with: pyproject.toml -COREMLTOOLS_VERSION="9.0b1" +COREMLTOOLS_VERSION="9.0" -red=`tput setaf 1` -green=`tput setaf 2` +# Safe colors (no TERM noise in CI) +if command -v tput >/dev/null 2>&1 && [ -t 1 ] && [ -n "${TERM:-}" ]; then + red="$(tput setaf 1)" + green="$(tput setaf 2)" + reset="$(tput sgr0)" +else + red=""; green=""; reset="" +fi EXECUTORCH_ROOT_PATH=$(realpath "$SCRIPT_DIR_PATH/../../../../") COREML_DIR_PATH="$EXECUTORCH_ROOT_PATH/backends/apple/coreml" @@ -25,30 +33,79 @@ PROTOBUF_FILES_DIR_PATH="$COREMLTOOLS_DIR_PATH/build/mlmodel/format/" cd "$EXECUTORCH_ROOT_PATH" rm -rf "$COREML_DIR_PATH/third-party" -mkdir "$COREML_DIR_PATH/third-party" +mkdir -p "$COREML_DIR_PATH/third-party" -echo "${green}ExecuTorch: Cloning coremltools." -git clone --depth 1 --branch "${COREMLTOOLS_VERSION}" "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH -cd $COREMLTOOLS_DIR_PATH +echo "${green}ExecuTorch: Cloning coremltools.${reset}" +git clone --depth 1 --branch "${COREMLTOOLS_VERSION}" "https://github.com/apple/coremltools.git" "$COREMLTOOLS_DIR_PATH" +cd "$COREMLTOOLS_DIR_PATH" STATUS=$? if [ $STATUS -ne 0 ]; then - echo "${red}ExecuTorch: Failed to clone coremltools." + echo "${red}ExecuTorch: Failed to clone coremltools.${reset}" exit 1 fi -echo "${green}ExecuTorch: Installing coremltools dependencies." -pip install -r "$COREMLTOOLS_DIR_PATH/reqs/build.pip" +# --------------------------------------------------------------------- +# Host toolchain / SDK setup JUST for coremltools build +# --------------------------------------------------------------------- +HOST_SDKROOT="${SDKROOT:-}" +HOST_CC="${CC:-}" +HOST_CXX="${CXX:-}" +HOST_CFLAGS="${CFLAGS:-}" +HOST_CXXFLAGS="${CXXFLAGS:-}" + +if [[ "$(uname)" == "Darwin" ]]; then + # Only pick macOS SDK if nothing else is specified + if [[ -z "$HOST_SDKROOT" ]]; then + HOST_SDKROOT="$(xcrun --sdk macosx --show-sdk-path)" + fi + if [[ -z "$HOST_CC" ]]; then + HOST_CC="$(xcrun --find clang)" + fi + if [[ -z "$HOST_CXX" ]]; then + HOST_CXX="$(xcrun --find clang++)" + fi + # Only add -isysroot if caller didn't already set CFLAGS/CXXFLAGS + if [[ -z "$HOST_CFLAGS" && -n "$HOST_SDKROOT" ]]; then + HOST_CFLAGS="-isysroot ${HOST_SDKROOT}" + fi + if [[ -z "$HOST_CXXFLAGS" && -n "$HOST_SDKROOT" ]]; then + HOST_CXXFLAGS="-isysroot ${HOST_SDKROOT}" + fi +fi + +echo "${green}ExecuTorch: Installing coremltools dependencies.${reset}" +SDKROOT="$HOST_SDKROOT" \ +CC="$HOST_CC" \ +CXX="$HOST_CXX" \ +CFLAGS="$HOST_CFLAGS" \ +CXXFLAGS="$HOST_CXXFLAGS" \ +python -m pip install -r "$COREMLTOOLS_DIR_PATH/reqs/build.pip" STATUS=$? if [ $STATUS -ne 0 ]; then - echo "${red}ExecuTorch: Failed to install coremltools dependencies." + echo "${red}ExecuTorch: Failed to install coremltools dependencies.${reset}" exit 1 fi -mkdir "$COREMLTOOLS_DIR_PATH/build" +mkdir -p "$COREMLTOOLS_DIR_PATH/build" + +echo "${green}ExecuTorch: Configuring coremltools CMake build.${reset}" +SDKROOT="$HOST_SDKROOT" \ +CC="$HOST_CC" \ +CXX="$HOST_CXX" \ +CFLAGS="$HOST_CFLAGS" \ +CXXFLAGS="$HOST_CXXFLAGS" \ cmake -S "$COREMLTOOLS_DIR_PATH" -B "$COREMLTOOLS_DIR_PATH/build" + +echo "${green}ExecuTorch: Building mlmodel target.${reset}" +SDKROOT="$HOST_SDKROOT" \ +CC="$HOST_CC" \ +CXX="$HOST_CXX" \ +CFLAGS="$HOST_CFLAGS" \ +CXXFLAGS="$HOST_CXXFLAGS" \ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel --target mlmodel -echo "${green}ExecuTorch: Copying protobuf files." +echo "${green}ExecuTorch: Copying protobuf files.${reset}" +rm -rf "$COREML_DIR_PATH/runtime/sdk/format/" mkdir -p "$COREML_DIR_PATH/runtime/sdk/format/" cp -rf "$PROTOBUF_FILES_DIR_PATH" "$COREML_DIR_PATH/runtime/sdk/format/" diff --git a/backends/apple/coreml/test/test_coreml_recipes.py b/backends/apple/coreml/test/test_coreml_recipes.py index 303d8cb78ed..98d240d74b5 100644 --- a/backends/apple/coreml/test/test_coreml_recipes.py +++ b/backends/apple/coreml/test/test_coreml_recipes.py @@ -326,7 +326,7 @@ def forward(self, x): ) self.check_fully_delegated(session) - self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-3) + self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-2) self._compare_eager_unquantized_model_outputs(session, model, example_inputs) def test_int8_weight_only_pt2e(self): diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 7d1a5496be3..1b27b027fc2 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -4,107 +4,55 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib -import os import typing -from enum import Enum +from typing import Any, Dict, final, List -from typing import Any, Dict, final, List, Optional, Set - -import torch -from executorch.backends.aoti.passes.replace_view_copy_with_view import ( - ReplaceViewCopyWithViewPass, -) -from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.backends.aoti.aoti_backend import AotiBackend from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu -from torch.export.passes import move_to_device_pass - - -# exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = { - "aoti_torch_mps_convolution": None, - "aoti_torch_mps_mm_out": None, - "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, -} - -# required fallback kernels but not supported -missing_fallback_kernels: Set[str] = set() - - -class COMPILE_SPEC_KEYS(Enum): - METHOD_NAME = "method_name" - - -# context manager for non-fallback guarantee -# it will raise exception when generating fallback kernels during aoti compile -@contextlib.contextmanager -def collect_unsupported_fallback_kernels(): - original_generate_c_shim_extern_kernel_call = ( - CppWrapperCpu.generate_c_shim_extern_kernel_call - ) - - def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - debug_handle: Optional[int] = None, - ): - if kernel not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel) - - original_generate_c_shim_extern_kernel_call( - self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle - ) - - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels - ) - try: - yield - finally: - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - original_generate_c_shim_extern_kernel_call - ) @final @experimental( "This API and all of Metal backend related functionality are experimental." ) -class MetalBackend(BackendDetails): - @staticmethod - def preprocess( - edge_program: ExportedProgram, - compile_specs: List[CompileSpec], - ) -> PreprocessResult: - print("entering the lowerable parts in MetalBackend.preprocess....") - # Move the edge_program from CPU to MPS for aoti compile - mps_edge_program = move_to_device_pass(edge_program, "mps") - - # replace slice_copy with slice - ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module) - - edge_program_module = mps_edge_program.module() - - # Grab all input placeholders from the graph - user_input_names = mps_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in mps_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) +class MetalBackend(AotiBackend, BackendDetails): + """ + MetalBackend is a backend that compiles a model to run on Metal/MPS devices. It uses the AOTInductor compiler to generate + optimized Metal kernels for the model's operators with libtorch-free. The compiled model can be executed on Metal devices + using the Executorch runtime. + """ + + @classmethod + def get_device_name(cls) -> str: + return "metal" + + @classmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + return { + "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_convolution": None, + "aoti_torch_mps_mm_out": None, + "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, + } - # Base options for all devices - options: dict[str, typing.Any] = { + @classmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + return {} + + @classmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return Metal-specific passes (currently none)""" + return [] + + @classmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """Get AOTI compile options for Metal backend.""" + _ = compile_specs # Unused, but required by interface + return { # Do not link against the full PyTorch/libtorch library "aot_inductor.link_libtorch": False, # Separate weight constants from the .so file @@ -117,83 +65,3 @@ def preprocess( # "aot_inductor.debug_compile": True, # "aot_inductor.force_mmap_weights": False, } - - with collect_unsupported_fallback_kernels(): - paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." - ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" - ) - - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = MetalBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. - named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob" - ) - - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Generates a CompileSpec for the given method name. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. - """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp index b5d2d3161ae..ebb5b7642e1 100644 --- a/backends/apple/metal/runtime/shims/memory.cpp +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -506,6 +506,15 @@ AOTITorchError aoti_torch__reinterpret_tensor( return Error::Ok; } +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle) { + (void)orig_handle; + (void)new_handle; + throw std::runtime_error("Not implemented"); + return Error::Internal; +} + // Cleanup function for clearing global state void cleanup_memory() { // Use aoti_torch_delete_tensor_object to properly delete each tensor diff --git a/backends/apple/metal/runtime/shims/memory.h b/backends/apple/metal/runtime/shims/memory.h index 5f48fd921c6..dda0e6bd6c7 100644 --- a/backends/apple/metal/runtime/shims/memory.h +++ b/backends/apple/metal/runtime/shims/memory.h @@ -64,6 +64,10 @@ AOTITorchError aoti_torch__reinterpret_tensor( int64_t storage_offset, AOTITensorHandle* ret_new_tensor); +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle); + void cleanup_memory(); } // extern "C" diff --git a/backends/arm/CMakeLists.txt b/backends/arm/CMakeLists.txt index ede7a96a389..ac9df82315e 100644 --- a/backends/arm/CMakeLists.txt +++ b/backends/arm/CMakeLists.txt @@ -48,17 +48,44 @@ endif() # VGF backend builds if(EXECUTORCH_BUILD_VGF) - - # include libvgf - set(LIBVGF_PATH - "${EXECUTORCH_ROOT}/examples/arm/ethos-u-scratch/ml-sdk-for-vulkan-manifest/sw/vgf-lib/" - ) - set(VULKAN_THIRD_PARTY_PATH ${EXECUTORCH_ROOT}/backends/vulkan/third-party) set(VULKAN_HEADERS_PATH ${VULKAN_THIRD_PARTY_PATH}/Vulkan-Headers/include) set(VOLK_HEADERS_PATH ${VULKAN_THIRD_PARTY_PATH}/volk) - set(LIBVGF_STATIC "${LIBVGF_PATH}/build/src/libvgf.a") + if(APPLE + OR CMAKE_SYSTEM_PROCESSOR MATCHES "^(arm64|aarch64)$" + OR EXISTS + "${EXECUTORCH_ROOT}/examples/arm/ethos-u-scratch/ml-sdk-for-vulkan-manifest/" + ) + message(STATUS "libvgf sourced from local scratch tree") + + # Legacy layout: libvgf sourced from local scratch tree + set(LIBVGF_PATH + "${EXECUTORCH_ROOT}/examples/arm/ethos-u-scratch/ml-sdk-for-vulkan-manifest/sw/vgf-lib/" + ) + set(LIBVGF_STATIC "${LIBVGF_PATH}/build/src/libvgf.a") + else() + message(STATUS "libvgf installed from pip package") + + set(Python3_FIND_VIRTUALENV FIRST) + if(EXECUTORCH_ROOT AND EXISTS "${EXECUTORCH_ROOT}/env") + set(Python3_EXECUTABLE "${EXECUTORCH_ROOT}/env/bin/python3") + endif() + + find_package(Python3 REQUIRED COMPONENTS Interpreter) + + # Prefer arch-specific site-packages if present, else pure + set(_vgf_site_arch "${Python3_SITEARCH}/vgf_lib/binaries") + set(_vgf_site_pure "${Python3_SITELIB}/vgf_lib/binaries") + if(EXISTS "${_vgf_site_arch}") + set(LIBVGF_PATH "${_vgf_site_arch}") + else() + set(LIBVGF_PATH "${_vgf_site_pure}") + endif() + + set(LIBVGF_STATIC "${LIBVGF_PATH}/lib/libvgf.a") + endif() + set(LIBVGF_INCLUDE "${LIBVGF_PATH}/include/") add_library(vgf STATIC IMPORTED) diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index be53c0b2600..6e81adfed6f 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -17,11 +17,7 @@ runtime.python_library( ) runtime.python_library( name = "common", - srcs = [ - "common/__init__.py", - "common/debug.py", - "common/type.py", - ], + srcs = glob(["common/*.py"]), deps = [ "fbsource//third-party/tosa_tools:serializer", "//caffe2:torch", @@ -68,6 +64,7 @@ runtime.python_library( "vgf/__init__.py", "vgf/backend.py", "vgf/compile_spec.py", + "vgf/model_converter.py", "vgf/partitioner.py", ], deps = [ @@ -84,7 +81,6 @@ runtime.python_library( "fbsource//third-party/tosa_tools:tosa", "//executorch/backends/arm/operators:node_visitor", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/exir:lib", ], diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index bb4e992ada1..a75c63fb86e 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -6,7 +6,6 @@ runtime.python_library( deps = [ "//executorch/backends/arm:common", "//executorch/backends/arm:constants", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/tosa/dialect:lib", "//executorch/backends/transforms:fuse_view_copy", diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index bc53606cab6..0b51c28cde8 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -13,26 +13,28 @@ from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa from .cast_to_int32_pass import CastToInt32Pass # noqa from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa -from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa from .convert_elu_params import ConvertELUParamsPass # noqa from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass # noqa from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa -from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa from .convert_minmax_pass import ConvertMinMaxPass # noqa +from .convert_permute_singleton_to_view_pass import ( # noqa + ConvertPermuteSingletonToViewPass, +) from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa -from .convert_to_clamp import ConvertToClampPass # noqa +from .convert_to_clamp_pass import ConvertToClampPass # noqa from .decompose_acosh_pass import DecomposeAcoshPass # noqa from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa from .decompose_addmm_pass import DecomposeAddmmPass # noqa +from .decompose_any_pass import DecomposeAnyPass # noqa from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa from .decompose_asinh_pass import DecomposeAsinhPass # noqa from .decompose_atan_pass import DecomposeAtanPass # noqa from .decompose_atanh_pass import DecomposeAtanhPass # noqa -from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa +from .decompose_avg_pool2d_pass import DecomposeAvgPool2dPass # noqa from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa from .decompose_cosh_pass import DecomposeCoshPass # noqa from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa @@ -45,22 +47,24 @@ from .decompose_floor_divide_pass import DecomposeFloorDividePass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa from .decompose_glu_pass import DecomposeGluPass # noqa -from .decompose_grouped_conv import DecomposeGroupedConv # noqa +from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa from .decompose_int16_activation_conv2d_pass import ( # noqa DecomposeConv2dWithInt16ActivationPass, ) +from .decompose_int_pow_pass import DecomposeIntPowPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa -from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa +from .decompose_linalg_vector_norm_pass import DecomposeLinalgVectorNormPass # noqa from .decompose_linear_pass import DecomposeLinearPass # noqa from .decompose_logit_pass import DecomposeLogitPass # noqa -from .decompose_masked_fill import DecomposeMaskedFill # noqa -from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa +from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa +from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa from .decompose_remainder_pass import DecomposeRemainderPass # noqa from .decompose_round_pass import DecomposeRoundPass # noqa +from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa from .decompose_select import DecomposeSelectPass # noqa from .decompose_sign_pass import DecomposeSignPass # noqa from .decompose_silu_pass import DecomposeSiluPass # noqa @@ -73,21 +77,31 @@ from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, - QuantizeOperatorArguments, + QuantizeClampArgumentsPass, +) +from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_constant_ops_pass import ( # noqa + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, ) -from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa -from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa +from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa from .insert_int32_casts_after_int64_placeholders import ( # noqa InsertInt32CastsAfterInt64PlaceholdersPass, ) -from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass # noqa +from .insert_rescales_pass import ( # noqa + InsertControlFlowRescalesPass, + InsertRescaleInt32Pass, + InsertRescalePass, +) from .insert_table_ops import InsertTableOpsPass # noqa from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa +from .remove_getitem_pass import RemoveGetItemPass # noqa +from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa from .remove_noop_pass import RemoveNoopPass # noqa from .replace_scalar_with_tensor_pass import ( # noqa ReplaceScalarWithTensorByProfilePass, @@ -100,5 +114,5 @@ from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa -from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip +from .replace_inf_values_pass import ReplaceInfValuesPass # noqa # usort: skip from .arm_pass_manager import ArmPassManager # noqa # usort: skip diff --git a/backends/arm/_passes/_debug_passes.py b/backends/arm/_passes/_debug_passes.py index e22c8a6cf2c..caaaec8ea5e 100644 --- a/backends/arm/_passes/_debug_passes.py +++ b/backends/arm/_passes/_debug_passes.py @@ -3,6 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import inspect +import os from typing import Set, Type import torch @@ -10,6 +12,7 @@ from executorch.devtools.visualization.visualization_utils import visualize_graph from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule class VisualizePass(ArmPass): @@ -26,3 +29,30 @@ def __init__(self, exported_program: ExportedProgram) -> None: def call(self, graph_module: torch.fx.GraphModule) -> PassResult: visualize_graph(graph_module, self.exported_program) return PassResult(graph_module, False) + + +class PrintGraphModuleCodePass(ArmPass): + """ + This pass prints the graph module's code to stdout for debugging purposes. + + Example output: + + [arm_pass_manager.py:305] + def forward(self, x, y): + x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) + remainder = torch.ops.aten.remainder.Scalar(x, 0.25); x = None + return pytree.tree_unflatten((remainder,), self._out_spec) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, label: str | None = None): + super().__init__() + caller_frame = inspect.stack()[1] + origin = f"{os.path.basename(caller_frame.filename)}:{caller_frame.lineno}" + self.label = f"[{label}]" if label is not None else f"[{origin}]" + + def call(self, graph_module: GraphModule) -> PassResult: + gm_code = graph_module.code.strip() + print(f"\n{self.label}\n{gm_code}") + return PassResult(graph_module, False) diff --git a/backends/arm/_passes/arm_pass.py b/backends/arm/_passes/arm_pass.py index f893eba4fc9..662cd6e8d97 100644 --- a/backends/arm/_passes/arm_pass.py +++ b/backends/arm/_passes/arm_pass.py @@ -6,14 +6,20 @@ import traceback from abc import abstractmethod -from typing import List, Optional, Set, Type +from typing import Any, List, Optional, Set, Type from executorch.exir.pass_base import ExportPass, NodeMetadata +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassResult class ArmPass(ExportPass): """Base class for Arm passes""" + def __init__(self) -> None: + super().__init__() + self.submodule_depth = 0 + @property @abstractmethod def _passes_required_after(self) -> Set[Type[ExportPass]]: @@ -56,3 +62,19 @@ def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False) old_stack_trace = new_meta.get("stack_trace", "") new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}" return super().call_operator(op, args, kwargs, NodeMetadata(new_meta)) + + def call_submodule( + self, graph_module: GraphModule, inputs: tuple[Any, ...] + ) -> PassResult: + self.submodule_depth += 1 + if self.submodule_depth == 1: + result = super().call_submodule(graph_module, inputs) + else: + # When we trace a submodule, we don't want to apply the calling pass. + # Temporarily replace call_operator to avoid this. + _call_operator_fn = self.call_operator + self.call_operator = super().call_operator # type: ignore + result = super().call_submodule(graph_module, inputs) + self.call_operator = _call_operator_fn # type: ignore + self.submodule_depth -= 1 + return result diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index a086d23dc40..75fc529a7e1 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -7,6 +7,7 @@ from collections import defaultdict +from collections.abc import Sequence import executorch.backends.arm.tosa.dialect # noqa: unused from executorch.backends.arm._passes import ( @@ -16,17 +17,16 @@ CastBoolToInt8Pass, CastInt64BuffersToInt32Pass, CastToInt32Pass, - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, Conv1dUnsqueezePass, - ConvertAnyDefaultDimDimsPass, ConvertELUParamsPass, ConvertExpandCopyToRepeatPass, ConvertFullLikeToFullPass, ConvertInt64ConstOpsToInt32Pass, ConvertInt64OutputOpsToInt32Pass, - ConvertIntPowToMuls, ConvertMinMaxPass, ConvertMmToBmmPass, + ConvertPermuteSingletonToViewPass, ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, @@ -34,11 +34,12 @@ DecomposeAdaptiveAvgPool2dPass, DecomposeAddmmPass, DecomposeAddSubAlphaPass, + DecomposeAnyPass, DecomposeAsinAndAcosPass, DecomposeAsinhPass, DecomposeAtanhPass, DecomposeAtanPass, - DecomposeAvgPool2d, + DecomposeAvgPool2dPass, DecomposeBatchNormNoStatsPass, DecomposeConv2dWithInt16ActivationPass, DecomposeCoshPass, @@ -52,19 +53,21 @@ DecomposeFloorDividePass, DecomposeGeluPass, DecomposeGluPass, - DecomposeGroupedConv, + DecomposeGroupedConvPass, DecomposeGroupNormPass, + DecomposeIntPowPass, DecomposeLayerNormPass, DecomposeLeakyReLUPass, + DecomposeLinalgVectorNormPass, DecomposeLinearPass, - DecomposeLinearVectorNormPass, DecomposeLogitPass, - DecomposeMaskedFill, - DecomposeMaxPool2DPass, + DecomposeMaskedFillPass, + DecomposeMaxPool2dPass, DecomposeMeanDimPass, DecomposeNotEqualPass, DecomposeRemainderPass, DecomposeRoundPass, + DecomposeScaledDotProductAttentionPass, DecomposeSelectPass, DecomposeSignPass, DecomposeSiluPass, @@ -76,20 +79,24 @@ DecomposeVarPass, DecorateFp32toInt32CastingPass, FoldAndAnnotateQParamsPass, - FuseBatchnorm2DPass, + FuseBatchNorm2dPass, FuseConstantArgsPass, FuseDuplicateUsersPass, FuseEqualPlaceholdersPass, FuseQuantizedActivationPass, + FuseViewCopyTransformPass, + InsertControlFlowRescalesPass, InsertInt32CastsAfterInt64PlaceholdersPass, InsertRescaleInt32Pass, InsertRescalePass, InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, - QuantizeOperatorArguments, + QuantizeClampArgumentsPass, + RemoveGetItemPass, + RemoveGraphAssertsPass, RemoveNoopPass, - ReplaceInfValues, + ReplaceInfValuesPass, ReplaceScalarWithTensorByProfilePass, RewriteConv2dPass, RewriteMatmulPass, @@ -106,14 +113,9 @@ TosaLoweringContext, TosaSpecification, ) -from executorch.backends.transforms.decompose_sdpa import ( - DecomposeScaledDotProductAttention, -) -from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform -from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram +from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager -from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult from torch.nn.modules import Module @@ -151,6 +153,11 @@ def validate_constraints_mandatory(self): raise RuntimeError(error_msg) + def add_passes(self, passes: Sequence[ExportPass | None]): + for p in passes: + if p is not None: + self.add_pass(p) + def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module @@ -158,96 +165,138 @@ def _transform(self, graph_module: GraphModule): def _tosa_pipeline( self, exported_program: ExportedProgram, graph_module: GraphModule ) -> GraphModule: + # Preprocessing passes self.add_pass(AnnotateOutputDimOrderPass()) - self.add_pass(FuseQuantizedActivationPass()) - self.add_pass(RemoveGetItemPass()) - self.add_pass(ConvertToClampPass()) - self.add_pass(DecomposeGroupNormPass()) - self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeBatchNormNoStatsPass()) - self.add_pass(DecomposeVarPass()) - self.add_pass( - DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) + + # Node transformation passes (pre q/dq folding) + self.add_passes( + [ + FuseQuantizedActivationPass(), + RemoveGetItemPass(), + ConvertToClampPass(), + DecomposeGroupNormPass(), + DecomposeLayerNormPass(), + DecomposeBatchNormNoStatsPass(), + DecomposeVarPass(), + DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec), + AnnotateDecomposedMatmulPass(), + ConvertELUParamsPass(), + ConvertSplitToSlicePass(), + QuantizeClampArgumentsPass(), + ] + ) + + # Fold Q/DQ nodes, insert INT8/INT32 rescales. + self.add_passes( + [ + FoldAndAnnotateQParamsPass(exported_program), + FuseDuplicateUsersPass(), + # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or + # before FoldAndAnnotateQParamsPass but is unable to at the moment. + # Ticket: MLETORCH-1539 + DecomposeLinearPass(), + InsertRescaleInt32Pass(), + InsertControlFlowRescalesPass(), + ] + ) + + # Node transformation passes (post q/dq folding) + self.add_passes( + [ + DecomposeLogitPass(), + DecomposeMaskedFillPass(), + DecomposeRoundPass(), + DecomposeAcoshPass(), + DecomposeAsinhPass(), + DecomposeCoshPass(), + DecomposeAsinAndAcosPass(), + DecomposeSqrtPass(), + DecomposeAtanPass(), + DecomposeAtanhPass(), + DecomposeAddmmPass(), + DecomposeEluPass(), + DecomposeExpm1Pass(), + DecomposeIntPowPass(), + CastBoolToInt8Pass(), + DecomposeSinhPass(), + DecomposeSignPass(), + DecomposeFloorDividePass(), + DecomposeGeluPass(), + DecomposeAddSubAlphaPass(), + DecomposeGroupedConvPass(), + Conv1dUnsqueezePass(), + ] + ) + + # Scalars -> tensors, match tensor dtypes and ranks. + self.add_passes( + [ + ReplaceScalarWithTensorByProfilePass(), + ConvertFullLikeToFullPass(), + MatchArgDtypePass(), + UnsqueezeScalarPlaceholdersPass(exported_program), + # TODO: Move DecomposeNotEqualPass to before or after this block of + # passes. Ticket: MLETORCH-1540 + DecomposeNotEqualPass(), + MatchArgRanksPass(exported_program), + FuseConstantArgsPass(exported_program), + ] + ) + + # Node transformation passes (post scalar-removal) + self.add_passes( + [ + DecomposeRemainderPass(), + DecomposeDivTensorModePass(), + DecomposeEmbeddingPass(), + FuseBatchNorm2dPass(exported_program), + ConvertMmToBmmPass(), + DecomposeGluPass(), + DecomposeLeakyReLUPass(), + DecomposeDivPass(), + DecomposeSoftmaxPass(), + ConvertMinMaxPass(), + DecomposeAnyPass(), + DecomposeAdaptiveAvgPool2dPass(), + DecomposeAvgPool2dPass(), + DecorateFp32toInt32CastingPass(), + ComputeConstantOpsAOTPass(exported_program), + ConvertExpandCopyToRepeatPass(), + UnsqueezeBeforeRepeatPass(), + DecomposeCumsumPass(exported_program), + DecomposeMaxPool2dPass(), + SizeAdjustInputPass(), + DecomposeSelectPass(), + ConvertSqueezesToViewPass(), + CastToInt32Pass(), + BroadcastArgsPass(), + ConvertPermuteSingletonToViewPass(), + FuseViewCopyTransformPass(), + DecomposeConv2dWithInt16ActivationPass(), + DecomposeSumPass(), + InsertTableOpsPass(exported_program), + ] + ) + + # Aten -> TOSA transformation passes + self.add_passes( + [ + RewriteUpsamplePass(), + RewriteConv2dPass(exported_program), + RewriteMatmulPass(), + ] + ) + + # Postprocessing/cleanup passes + self.add_passes( + [ + CastInt64BuffersToInt32Pass(exported_program), + FuseEqualPlaceholdersPass(exported_program), + ToTosaMemoryFormatPass(exported_program), + RemoveNoopPass(), + InsertRescalePass(), + ] ) - self.add_pass(AnnotateDecomposedMatmulPass()) - self.add_pass(ConvertELUParamsPass()) - self.add_pass(ConvertSplitToSlicePass()) - self.add_pass(QuantizeOperatorArguments()) - self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] - self.add_pass(FuseDuplicateUsersPass()) - self.add_pass(DecomposeExpm1Pass()) - self.add_pass(DecomposeLogitPass()) - self.add_pass(DecomposeMaskedFill()) - self.add_pass(DecomposeRoundPass()) - self.add_pass(DecomposeAcoshPass()) - self.add_pass(DecomposeAsinhPass()) - self.add_pass(DecomposeCoshPass()) - self.add_pass(DecomposeAsinAndAcosPass()) - self.add_pass(DecomposeSqrtPass()) - self.add_pass(DecomposeAtanPass()) - self.add_pass(DecomposeAtanhPass()) - self.add_pass(DecomposeAddmmPass()) - self.add_pass(DecomposeEluPass()) - self.add_pass(DecomposeExpm1Pass()) - self.add_pass(ConvertIntPowToMuls()) - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(DecomposeSinhPass()) - self.add_pass(DecomposeSignPass()) - self.add_pass(DecomposeFloorDividePass()) - self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(ReplaceScalarWithTensorByProfilePass()) - self.add_pass(DecomposeRemainderPass()) - self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(DecomposeEmbeddingPass()) - self.add_pass(FuseBatchnorm2DPass(exported_program)) - self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeGluPass()) - self.add_pass(DecomposeLinearPass()) - self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeNotEqualPass()) - self.add_pass(DecomposeDivPass()) - self.add_pass(DecomposeAddSubAlphaPass()) - self.add_pass(DecomposeSoftmaxPass()) - self.add_pass(DecomposeGeluPass()) - self.add_pass(ConvertFullLikeToFullPass()) - self.add_pass(ConvertMinMaxPass()) - self.add_pass(ConvertAnyDefaultDimDimsPass()) - self.add_pass(MatchArgDtypePass()) - self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) - self.add_pass(MatchArgRanksPass(exported_program)) - self.add_pass(DecomposeAdaptiveAvgPool2dPass()) - self.add_pass(DecomposeAvgPool2d()) - self.add_pass( - DecorateFp32toInt32CastingPass() - ) # Require that no new fp32->int32 is introduced after this pass - self.add_pass(ComputeConstantOpsAOT(exported_program)) - - self.add_pass(DecomposeGroupedConv()) - self.add_pass(ConvertExpandCopyToRepeatPass()) - self.add_pass(UnsqueezeBeforeRepeatPass()) - self.add_pass(DecomposeCumsumPass(exported_program)) - self.add_pass(Conv1dUnsqueezePass()) - self.add_pass(DecomposeMaxPool2DPass()) - self.add_pass(SizeAdjustInputPass()) - self.add_pass(DecomposeSelectPass()) - self.add_pass(ConvertSqueezesToViewPass()) - self.add_pass(CastToInt32Pass()) - self.add_pass(BroadcastArgsPass()) - - self.add_pass(FuseViewCopyTransform()) - self.add_pass(FuseConstantArgsPass(exported_program)) - self.add_pass(DecomposeConv2dWithInt16ActivationPass()) - self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) - self.add_pass(InsertTableOpsPass(exported_program)) - self.add_pass(RewriteUpsamplePass()) - self.add_pass(RewriteConv2dPass(exported_program)) - self.add_pass(RewriteMatmulPass()) - self.add_pass(FuseEqualPlaceholdersPass(exported_program)) - self.add_pass(InsertRescaleInt32Pass()) - self.add_pass(DecomposeSumPass()) - self.add_pass(ToTosaMemoryFormatPass(exported_program)) - self.add_pass(RemoveNoopPass()) - self.add_pass(InsertRescalePass()) self.validate_constraints_mandatory() return self._transform(graph_module) @@ -263,55 +312,73 @@ def transform_to_backend_pipeline( return self._tosa_pipeline(exported_program, graph_module) else: raise NotImplementedError( - f"No pass pipeline implemented for {self.tosa_spec=}" + f"No pass pipeline implemented for {self.tosa_spec}" ) def transform_for_annotation_pipeline(self, graph_module: GraphModule): - self.add_pass( - RemoveGraphAssertsPass() - ) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph - self.add_pass(ConvertInt64ConstOpsToInt32Pass()) - self.add_pass(ConvertInt64OutputOpsToInt32Pass()) - self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass()) - self.add_pass(DecomposeEmbeddingPass()) - self.add_pass(DecomposeScaledDotProductAttention()) - self.add_pass(DecomposeRoundPass()) - self.add_pass(DecomposeLogitPass()) - self.add_pass(CastBoolToInt8Pass()) - self.add_pass(DecomposeSignPass()) - self.add_pass(DecomposeAddmmPass()) - self.add_pass(ReplaceScalarWithTensorByProfilePass()) - self.add_pass(DecomposeRemainderPass()) - self.add_pass(DecomposeFloorDividePass()) - self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(DecomposeAddSubAlphaPass()) - self.add_pass(ScalarsToAttributePass()) - self.add_pass(DecomposeGroupNormPass()) - self.add_pass(DecomposeLayerNormPass()) - self.add_pass(DecomposeVarPass()) - self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) - self.add_pass(DecomposeNotEqualPass()) - self.add_pass(DecomposeCosineSimilarityPass()) - self.add_pass(DecomposeGluPass()) - self.add_pass(DecomposeDivPass()) - self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeLinearVectorNormPass()) - self.add_pass(DecomposeSqrtPass()) - self.add_pass(DecomposeSiluPass()) - self.add_pass(DecomposeAvgPool2d()) - - if self.tosa_spec.is_U55_subset: - # Numerically stable softmax uses amax which is not supported on Ethos-U55 - self.add_pass(DecomposeSoftmaxUnstablePass()) - else: - self.add_pass(DecomposeSoftmaxPass()) + # Preprocessing passes + self.add_pass(RemoveGraphAssertsPass()) - self.add_pass(ConvertMinMaxPass()) - self.add_pass(ReplaceInfValues()) + # Transformation passes (pre scalar -> tensor) + self.add_passes( + [ + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + InsertInt32CastsAfterInt64PlaceholdersPass(), + DecomposeEmbeddingPass(), + DecomposeScaledDotProductAttentionPass(), + DecomposeRoundPass(), + DecomposeLogitPass(), + CastBoolToInt8Pass(), + DecomposeSignPass(), + DecomposeAddmmPass(), + DecomposeRemainderPass(), + DecomposeFloorDividePass(), + DecomposeDivTensorModePass(), + ] + ) - if not self.tosa_spec.is_U55_subset: - # Uses where which is not supported on Ethos-U55 - self.add_pass(DecomposeMaskedFill()) + # Scalars -> tensors + self.add_passes( + [ + ReplaceScalarWithTensorByProfilePass(), + ScalarsToAttributePass(), + ] + ) + + # Transformation passes (post scalar removal) + self.add_passes( + [ + DecomposeAddSubAlphaPass(), + DecomposeGroupNormPass(), + DecomposeLayerNormPass(), + DecomposeVarPass(), + DecomposeMeanDimPass(graph_module, self.tosa_spec), + DecomposeNotEqualPass(), + DecomposeCosineSimilarityPass(), + DecomposeGluPass(), + DecomposeDivPass(), + DecomposeLeakyReLUPass(), + DecomposeLinalgVectorNormPass(), + DecomposeSqrtPass(), + DecomposeSiluPass(), + DecomposeAvgPool2dPass(), + ( + DecomposeSoftmaxUnstablePass() + if self.tosa_spec.is_U55_subset + else DecomposeSoftmaxPass() + ), + ConvertMinMaxPass(), + ] + ) + + # Postprocessing passes + self.add_passes( + [ + ReplaceInfValuesPass(), + DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None, + ] + ) return self._transform(graph_module) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index de42c961d08..b9aa04236eb 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -31,11 +31,25 @@ from torch.export.graph_signature import InputKind +def is_submodule_node(node: torch.fx.Node): + if node.op not in ("get_attr", "placeholder"): + return False + try: + node.graph.owning_module.get_submodule(node.target) + except AttributeError: + return False + return True + + def is_get_attr_node(node: torch.fx.Node) -> bool: """ - Returns true if the given node is a get attr node for a tensor of the model + Returns true if the given node is a get attr node for a tensor of the model. """ - return isinstance(node, torch.fx.Node) and node.op == "get_attr" + return ( + isinstance(node, torch.fx.Node) + and node.op == "get_attr" + and not is_submodule_node(node) + ) def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: @@ -100,6 +114,7 @@ def create_node( quantize: bool = False, q_params: Optional[tuple] = None, from_node: Optional[torch.fx.Node] = None, + inherit_qparams: bool = False, ): """ Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node. @@ -118,6 +133,14 @@ def create_node( keys = from_node.meta.keys() for key in keys: new_meta[key] = from_node.meta[key] + if not inherit_qparams: + if "input_qparams" in new_meta: + new_meta["input_qparams"] = {} + if "output_qparams" in new_meta: + new_meta["output_qparams"] = {} + elif inherit_qparams: + raise ValueError("inherit_qparams is only valid when from_node is given") + old_stack_trace = new_meta.get("stack_trace", "") new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}" node.meta = new_meta @@ -202,7 +225,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None): f"Out of bounds index {key} for getting value in args (of size {len(args)})" ) elif isinstance(key, str): - return args.get(key, default_value) # type: ignore[union-attr] # pyre-ignore[16] + return args.get(key, default_value) # type: ignore[union-attr] elif isclass(key): for arg in args: if isinstance(arg, key): diff --git a/backends/arm/_passes/broadcast_args_pass.py b/backends/arm/_passes/broadcast_args_pass.py index 131b749b702..d11fb779280 100644 --- a/backends/arm/_passes/broadcast_args_pass.py +++ b/backends/arm/_passes/broadcast_args_pass.py @@ -63,6 +63,7 @@ def call(self, graph_module: GraphModule) -> PassResult: args=(arg, multiples), kwargs={}, from_node=node, + inherit_qparams=False, ) node.replace_input_with(arg, repeat) diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 4822c6c25c0..02a9cbeceaf 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -41,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: if len(node.users) == 0: continue + if "val" not in node.meta: + continue fake_tensor = node.meta["val"] if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): continue diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index a368f1b65ed..b6cf8ffa41b 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -40,13 +40,17 @@ def call_operator(self, op, args, kwargs, meta): if len(stride) != 1: return super().call_operator(op, args, kwargs, meta) + x_meta = meta.copy() + x_meta.data["input_qparams"] = {} + x_meta.data["output_qparams"] = {} + x = args[0] x_unsqueezed_shape = list(x.data.shape) + [1] x = super().call_operator( exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, - meta, + x_meta, updated=True, ) @@ -79,12 +83,15 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.convolution.default, new_args, kwargs, meta, updated=True ) + x_squeezed_meta = meta.copy() + x_squeezed_meta.data["input_qparams"] = {} + x_squeezed_meta.data["output_qparams"] = {} x_squeezed_shape = list(x.data.shape)[:-1] x = super().call_operator( exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, - meta, + x_squeezed_meta, updated=True, ) diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index f932ae7f4c4..0cd306086cb 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -20,6 +20,7 @@ def calculate_multiples(args): + """Returns expand args converted to repeat args, and whether the expand changes the rank""" input_node_or_tensor = args[0] if isinstance(input_node_or_tensor, torch.fx.node.Node): @@ -45,7 +46,7 @@ def calculate_multiples(args): multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1 for i in range(expanded_rank) ] - return multiples + return multiples, expanded_rank != len(input_shape) class ConvertExpandCopyToRepeatPass(ArmPass): @@ -62,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta): if op != self.expand_copy: return super().call_operator(op, args, kwargs, meta) - multiples = calculate_multiples(args) + multiples, changes_rank = calculate_multiples(args) - if all((x == 1 for x in multiples)): + if all((x == 1 for x in multiples)) and not changes_rank: # All dimensions/repetitions occur only once. Remove node # altogether since it's in practice just a copy. logger.warning("Found redundant expand node (no-op). Removing it.") diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index 06822a4abcf..becb0b7f971 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -6,7 +6,9 @@ from typing import Set, Type from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -24,7 +26,7 @@ class ConvertFullLikeToFullPass(ArmPass): Skip layout and device since it's not relevant for our backend. """ - _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} def call_operator(self, op, args, kwargs, meta): if op not in [ diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py index dff270fda13..85fcf715f07 100644 --- a/backends/arm/_passes/convert_int64_const_ops_to_int32.py +++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py @@ -9,7 +9,9 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.pass_base import ExportPass, PassResult @@ -30,7 +32,7 @@ class ConvertInt64ConstOpsToInt32Pass(ArmPass): 5. `torch.tensor` """ - _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} torch_ops = [ torch.ops.aten.full.default, @@ -47,7 +49,10 @@ def call(self, graph_module: torch.fx.GraphModule): if node.op != "call_function": continue - if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops: + if ( + node.target + not in ComputeConstantOpsAOTPass.targeted_ops + self.torch_ops + ): continue data = node.target(*node.args, **node.kwargs) diff --git a/backends/arm/_passes/convert_permute_singleton_to_view_pass.py b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py new file mode 100644 index 00000000000..fe4697bc213 --- /dev/null +++ b/backends/arm/_passes/convert_permute_singleton_to_view_pass.py @@ -0,0 +1,64 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Sequence, Set, Tuple, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +from torch._ops import OpOverload + + +_PERMUTE_TARGETS: Tuple[OpOverload, ...] = ( + exir_ops.edge.aten.permute.default, + exir_ops.edge.aten.permute_copy.default, +) + + +class ConvertPermuteSingletonToViewPass(ArmPass): + """Replace permutations that only move singleton axes with a reshape. + + Examples: + x = rand(1,1,1,4) + y = permute(x, (0,3,1,2)) + + becomes: + x = rand(1,1,1,4) + y = view_copy(x, (1,4,1,1)) + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta): + if op not in _PERMUTE_TARGETS: + return super().call_operator(op, args, kwargs, meta) + + input_tensor = args[0].data + permutation = args[1] + if not is_singleton_permutation(input_tensor.shape, permutation): + return super().call_operator(op, args, kwargs, meta) + + output_shape = meta["val"].shape + view_args = (args[0], output_shape) + return super().call_operator( + exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta + ) + + +def is_singleton_permutation(shape: Sequence[int], permutation: Sequence[int]) -> bool: + """ + Treat as a view only when non-singleton axes keep their order; singleton + axes may move freely since they carry no data volume. + """ + rank = len(shape) + normalized_perm = [d % rank for d in permutation] + + non_singleton_axes = [i for i, size in enumerate(shape) if size != 1] + permuted_non_singleton_axes = [axis for axis in normalized_perm if shape[axis] != 1] + + return permuted_non_singleton_axes == non_singleton_axes diff --git a/backends/arm/_passes/convert_squeezes_to_view.py b/backends/arm/_passes/convert_squeezes_to_view.py index f7b9df3b5f4..9d185a8e08c 100644 --- a/backends/arm/_passes/convert_squeezes_to_view.py +++ b/backends/arm/_passes/convert_squeezes_to_view.py @@ -1,5 +1,4 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,9 +7,9 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass - -from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform - +from executorch.backends.arm._passes.fuse_view_copy_transform_pass import ( + FuseViewCopyTransformPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -20,7 +19,7 @@ class ConvertSqueezesToViewPass(ArmPass): Replaces squeeze/unsqueeze operators with view. These are simply special cases of the view op, so removing them gives us less cases to handle in the node visitiors. """ - _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform} + _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass} def call_operator(self, op, args, kwargs, meta): if op not in [ diff --git a/backends/arm/_passes/convert_to_clamp.py b/backends/arm/_passes/convert_to_clamp_pass.py similarity index 91% rename from backends/arm/_passes/convert_to_clamp.py rename to backends/arm/_passes/convert_to_clamp_pass.py index 1ada1efe69b..4b28f993acd 100644 --- a/backends/arm/_passes/convert_to_clamp.py +++ b/backends/arm/_passes/convert_to_clamp_pass.py @@ -8,7 +8,7 @@ from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - QuantizeOperatorArguments, + QuantizeClampArgumentsPass, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -30,7 +30,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]: class ConvertToClampPass(ArmPass): - _passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments} + _passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass} def call_operator(self, op, args, kwargs, meta): if op not in edge_operators: diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index 52ddb77151d..5905e8f4496 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -9,10 +9,12 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d +from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( + DecomposeAvgPool2dPass, +) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, NodeMetadata edge_ops = (exir_ops.edge.aten._adaptive_avg_pool2d.default,) aten_ops = (torch.ops.aten.adaptive_avg_pool2d.default,) @@ -44,7 +46,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass): The output is of size output_size_h x output_size_w for any input. """ - _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2d} + _passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass} def call_operator(self, op, args, kwargs, meta, updated=False): if op not in (edge_ops + aten_ops): @@ -60,6 +62,11 @@ def call_operator(self, op, args, kwargs, meta, updated=False): # Vela currently only allows a stride in the interval of [1,3] for AvgPool2d. # To accommodate this, the AvgPool2d op is applied to pooling regions and the results are concatenated. + # Slices and concats does not require quantization parameters + metadata_dict = dict(meta.data) + metadata_dict["input_qparams"] = {} + metadata_dict["output_qparams"] = {} + meta_with_no_qparams = NodeMetadata(metadata_dict) res = [] for out_i in range(output_size_h): row = [] @@ -72,11 +79,15 @@ def call_operator(self, op, args, kwargs, meta, updated=False): # Slice along H x_h = super().call_operator( - slice_op, (x, 2, start_h, end_h), kwargs, meta, True + slice_op, (x, 2, start_h, end_h), kwargs, meta_with_no_qparams, True ) # Slice along W x_hw = super().call_operator( - slice_op, (x_h, 3, start_w, end_w), kwargs, meta, True + slice_op, + (x_h, 3, start_w, end_w), + kwargs, + meta_with_no_qparams, + True, ) # Apply avg pooling with kernel size equal to the pooling region @@ -89,9 +100,13 @@ def call_operator(self, op, args, kwargs, meta, updated=False): row.append(pooled) # Concatenate row results along width (dim=3) - row_tensor = super().call_operator(cat_op, (row, 3), kwargs, meta, True) + row_tensor = super().call_operator( + cat_op, (row, 3), kwargs, meta_with_no_qparams, True + ) res.append(row_tensor) # Concatenate all rows along height (dim=2) - out = super().call_operator(cat_op, (res, 2), kwargs, meta, True) + out = super().call_operator( + cat_op, (res, 2), kwargs, meta_with_no_qparams, True + ) return out diff --git a/backends/arm/_passes/convert_any_default_dim_dims_pass.py b/backends/arm/_passes/decompose_any_pass.py similarity index 70% rename from backends/arm/_passes/convert_any_default_dim_dims_pass.py rename to backends/arm/_passes/decompose_any_pass.py index d09cd22cbd4..a0487e7e139 100644 --- a/backends/arm/_passes/convert_any_default_dim_dims_pass.py +++ b/backends/arm/_passes/decompose_any_pass.py @@ -7,9 +7,7 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.convert_squeezes_to_view import ( - ConvertSqueezesToViewPass, -) +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.exir.dialects._ops import ( # type: ignore[import-not-found] ops as exir_ops, ) @@ -19,27 +17,29 @@ ) -class ConvertAnyDefaultDimDimsPass(ArmPass): +class DecomposeAnyPass(ArmPass): """ - Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction. - Please refer to KeepDimsFalseToSqueezePass for an explanation of this coversion. + Converts any.default, any.dim and any.dims to a sequence of any.dim by + unrolling multi-dimensional reductions with keepdim=True. If keepdim=False + was requested, the final shape adjustment is implemented with a + view_copy.default to the reduced shape. Example 1 Original: - any() # x.shape: [dim1, dim2, ..., dimn] + any.dim() # x.shape: [dim1, dim2, ..., dimn] After pass: any.dim(dim1, keepdim = True) any.dim(dim2, keepdim = True) ... any.dim(dimn, keepdim = True) - squeeze(dim = [dim1, dim2, ...., dimn]) + view_copy(shape = squeezed_shape) Example 2 Original: any.dim(dim1, keepdim = False) After pass: any.dim(dim1, keepdim = True) - squeeze(dim = [dim1]) + view_copy(shape = squeezed_shape) Example 3 Original: @@ -47,10 +47,10 @@ class ConvertAnyDefaultDimDimsPass(ArmPass): After pass: any.dim(dim1, keepdim = True) any.dim(dim2, keepdim = True) - squeeze(dim = [dim1, dim2]) + view_copy(shape = squeezed_shape) """ - _passes_required_after: Set[Type[ExportPass]] = {ConvertSqueezesToViewPass} + _passes_required_after: Set[Type[ExportPass]] = set() def call(self, graph_module: torch.fx.GraphModule): modified = False @@ -67,40 +67,40 @@ def call(self, graph_module: torch.fx.GraphModule): if len(node.args) == 1: # any.default(input) input_node = (node.args)[0] - dims = range(len(input_node.meta["val"].shape)) + dims_to_reduce = range(len(input_node.meta["val"].shape)) keepdim = False elif len(node.args) == 2: # any.dim/dims(input, dims=dims) - input_node, dims = node.args + input_node, dims_to_reduce = node.args keepdim = False elif len(node.args) == 3: # any.dim/dims(input, dims=dims, keepdim=keepdim) - input_node, dims, keepdim = node.args + input_node, dims_to_reduce, keepdim = node.args else: raise RuntimeError( f"Unexpected arg size {len(node.args)} in {node.name}" ) try: - iter(dims) + iter(dims_to_reduce) except: - dims = [dims] # type: ignore[assignment] + dims_to_reduce = [dims_to_reduce] # type: ignore[assignment] else: - dims = list(dims) # type: ignore[assignment] + dims_to_reduce = list(dims_to_reduce) # type: ignore[assignment] # Unroll multi-dimensional reduction and keep-dims arg with graph_module.graph.inserting_before(node): - for dim in dims: + for dim in dims_to_reduce: args = (input_node, dim, True) input_node = graph_module.graph.create_node( "call_function", exir_ops.edge.aten.any.dim, args, node.kwargs ) if not keepdim: - args = (input_node, dims) # type: ignore[assignment] + output_shape = list(get_first_fake_tensor(node).shape) input_node = graph_module.graph.create_node( "call_function", - exir_ops.edge.aten.squeeze_copy.dims, - args, + exir_ops.edge.aten.view_copy.default, + (input_node, output_shape), ) node.replace_all_uses_with(input_node) diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py similarity index 98% rename from backends/arm/_passes/decompose_avg_pool2d.py rename to backends/arm/_passes/decompose_avg_pool2d_pass.py index 0187ee45a1e..14b03cf6243 100644 --- a/backends/arm/_passes/decompose_avg_pool2d.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -8,7 +8,9 @@ import torch from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, ) @@ -37,8 +39,8 @@ def get_decomposition(op) -> tuple: raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}") -class DecomposeAvgPool2d(ArmPass): - _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT} +class DecomposeAvgPool2dPass(ArmPass): + _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass} def call_operator(self, op, args, kwargs, meta): if op not in (edge_div_ops + aten_div_ops): diff --git a/backends/arm/_passes/decompose_batch_norm_no_stats.py b/backends/arm/_passes/decompose_batch_norm_no_stats.py index ef9b9f859cd..9a486376617 100644 --- a/backends/arm/_passes/decompose_batch_norm_no_stats.py +++ b/backends/arm/_passes/decompose_batch_norm_no_stats.py @@ -10,7 +10,9 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops @@ -37,7 +39,7 @@ class DecomposeBatchNormNoStatsPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, InsertTableOpsPass, } diff --git a/backends/arm/_passes/decompose_cumsum_pass.py b/backends/arm/_passes/decompose_cumsum_pass.py index 7066fdb16eb..dedbc2c039f 100644 --- a/backends/arm/_passes/decompose_cumsum_pass.py +++ b/backends/arm/_passes/decompose_cumsum_pass.py @@ -101,7 +101,13 @@ def call(self, graph_module): with graph.inserting_before(node): # Reshape to 4D with view_args = (input_node, conv_shape) - view_node = create_node(graph, view_op, args=view_args, from_node=node) + view_node = create_node( + graph, + view_op, + args=view_args, + from_node=node, + inherit_qparams=False, + ) conv_args = ( view_node, @@ -114,7 +120,9 @@ def call(self, graph_module): [0], 1, ) - conv_node = create_node(graph, conv_op, args=conv_args, from_node=node) + conv_node = create_node( + graph, conv_op, args=conv_args, from_node=node, inherit_qparams=True + ) # The convolution is inserted after quantization, so we need to set our # own quantization parameters for the weights here. However since the @@ -129,12 +137,20 @@ def call(self, graph_module): slice_args = (conv_node, 2, 0, original_shape[dim]) slice_node = create_node( - graph, slice_op, args=slice_args, from_node=node + graph, + slice_op, + args=slice_args, + from_node=node, + inherit_qparams=False, ) view_original_args = (slice_node, original_shape) view_original_node = create_node( - graph, view_op, args=view_original_args, from_node=node + graph, + view_op, + args=view_original_args, + from_node=node, + inherit_qparams=False, ) # Replace and remove original diff --git a/backends/arm/_passes/decompose_embedding_pass.py b/backends/arm/_passes/decompose_embedding_pass.py index a87b26366d7..e9c8f303cbf 100644 --- a/backends/arm/_passes/decompose_embedding_pass.py +++ b/backends/arm/_passes/decompose_embedding_pass.py @@ -10,14 +10,15 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform +from executorch.backends.arm._passes.fuse_view_copy_transform_pass import ( + FuseViewCopyTransformPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from .arm_pass_utils import create_node, get_first_fake_tensor logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) class DecomposeEmbeddingPass(ArmPass): @@ -34,7 +35,7 @@ class DecomposeEmbeddingPass(ArmPass): i = indices is expected to be int32 before this pass """ - _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransform} + _passes_required_after: Set[Type[ExportPass]] = {FuseViewCopyTransformPass} aten_ops = (torch.ops.aten.embedding.default,) edge_ops = (exir_ops.edge.aten.embedding.default,) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 09a891c34dc..d2eb908e925 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -6,8 +6,8 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.convert_int_pow_to_mul import ConvertIntPowToMuls from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm._passes.decompose_int_pow_pass import DecomposeIntPowPass from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -80,7 +80,7 @@ class DecomposeExpm1Pass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ConvertIntPowToMuls, + DecomposeIntPowPass, InsertTableOpsPass, DecomposeDivPass, ReplaceScalarWithTensorByProfilePass, diff --git a/backends/arm/_passes/decompose_gelu_pass.py b/backends/arm/_passes/decompose_gelu_pass.py index 2a25e6dbb6d..5bf39370835 100644 --- a/backends/arm/_passes/decompose_gelu_pass.py +++ b/backends/arm/_passes/decompose_gelu_pass.py @@ -8,7 +8,9 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass @@ -85,7 +87,7 @@ class DecomposeGeluPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv_pass.py similarity index 95% rename from backends/arm/_passes/decompose_grouped_conv.py rename to backends/arm/_passes/decompose_grouped_conv_pass.py index 11d9f605127..a0765b865fc 100644 --- a/backends/arm/_passes/decompose_grouped_conv.py +++ b/backends/arm/_passes/decompose_grouped_conv_pass.py @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposeGroupedConv(ArmPass): +class DecomposeGroupedConvPass(ArmPass): """ Splits a grouped convolution which is not supported by TOSA into multiple convolutions using slice->conv->cat. @@ -81,7 +81,7 @@ def _get_meta_copy(meta, i, output_slice_size): new_qparams = meta.data.get("input_qparams").copy() # Get quantization params of the weights and slice them. qarg = new_qparams[1] - new_qparams[1] = DecomposeGroupedConv._split_per_channel_qparams( + new_qparams[1] = DecomposeGroupedConvPass._split_per_channel_qparams( qarg, index=i, output_slice_size=output_slice_size ) @@ -117,7 +117,7 @@ def call_operator(self, op, args, kwargs, meta): no_q_dq_meta.data = {} no_q_dq_meta.data = {} - slice_op, conv_op, cat_op = DecomposeGroupedConv._get_decomposition(op) + slice_op, conv_op, cat_op = DecomposeGroupedConvPass._get_decomposition(op) input_slices = [] for i in range(groups): @@ -163,7 +163,9 @@ def call_operator(self, op, args, kwargs, meta): zip(input_slices, filter_slices, bias_slices) ): - meta_copy = DecomposeGroupedConv._get_meta_copy(meta, i, output_slice_size) + meta_copy = DecomposeGroupedConvPass._get_meta_copy( + meta, i, output_slice_size + ) if op == exir_ops.edge.aten.convolution.default: conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1) diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py index ac4f271b744..2f160474c5b 100644 --- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -4,17 +4,18 @@ # LICENSE file in the root directory of this source tree. -from typing import cast +from typing import cast, Set, Type import torch +from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.quant_args import QuantArgs -from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00 +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -class DecomposeConv2dWithInt16ActivationPass(ExportPass): +class DecomposeConv2dWithInt16ActivationPass(ArmPass): """ This pass decomposes a convolution with input dtype int16 and bias into a convolution without bias followed by an addition of the bias @@ -22,6 +23,8 @@ class DecomposeConv2dWithInt16ActivationPass(ExportPass): in torch. Instead rescale the int48 output to int16 and add the bias in int16. """ + _passes_required_after: Set[Type[ExportPass]] = set() + def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.convolution.default: return super().call_operator(op, args, kwargs, meta) @@ -37,9 +40,7 @@ def call_operator(self, op, args, kwargs, meta): if args[0].data.dtype == torch.int8: return super().call_operator(op, args, kwargs, meta) elif args[0].data.dtype == torch.int16: - if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension( - "int16" - ): + if not tosa_spec.support_extension("int16"): raise ValueError( "int16 activation for convolution requires TOSA int16 extension" ) diff --git a/backends/arm/_passes/convert_int_pow_to_mul.py b/backends/arm/_passes/decompose_int_pow_pass.py similarity index 98% rename from backends/arm/_passes/convert_int_pow_to_mul.py rename to backends/arm/_passes/decompose_int_pow_pass.py index e2c8bd0c4d6..4db5e45c120 100644 --- a/backends/arm/_passes/convert_int_pow_to_mul.py +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -11,7 +11,7 @@ from executorch.exir.pass_base import ExportPass -class ConvertIntPowToMuls(ArmPass): +class DecomposeIntPowPass(ArmPass): """ Replaces pow with integer exponent with a series of multiplications. Only handles pow.Tensor_Scalar and not pow.Tensor_Tensor. diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 7623e410cf9..5f56de92512 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -12,7 +12,9 @@ from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -61,7 +63,7 @@ class DecomposeLayerNormPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, DecomposeMeanDimPass, DecomposeVarPass, InsertTableOpsPass, diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 5c6c8fc0ec5..83bbc6669ef 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -12,7 +12,7 @@ from executorch.exir.pass_base import ExportPass -class DecomposeLinearVectorNormPass(ArmPass): +class DecomposeLinalgVectorNormPass(ArmPass): """ This pass decomposes aten.linalg_vector_norm.default into more primitive ops. We need to add this pass before quantization for graph annotation. diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index ffe63f8cb65..e1a9cfd0bfc 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -12,6 +12,7 @@ create_node, get_first_fake_tensor, ) +from executorch.backends.arm._passes.insert_rescales_pass import InsertRescaleInt32Pass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -26,7 +27,7 @@ class DecomposeLinearPass(ArmPass): output = view(conv2d) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {InsertRescaleInt32Pass} def call(self, graph_module): for node in graph_module.graph.nodes: @@ -54,6 +55,8 @@ def call(self, graph_module): op_target=exir_ops.edge.aten.view_copy.default, args=(input, input_reshaped_shape), kwargs={}, + from_node=node, + inherit_qparams=False, ) # Reshape weights to 4D with shape (Co, Ci, 1, 1) @@ -62,6 +65,8 @@ def call(self, graph_module): op_target=exir_ops.edge.aten.view_copy.default, args=(weights, weights_reshaped_shape), kwargs={}, + from_node=node, + inherit_qparams=False, ) conv = create_node( @@ -80,6 +85,7 @@ def call(self, graph_module): ), kwargs={}, from_node=node, + inherit_qparams=True, ) with graph_module.graph.inserting_after(conv): @@ -92,14 +98,8 @@ def call(self, graph_module): args=(conv, list(output_shape)), kwargs={}, from_node=node, + inherit_qparams=False, ) - # Quantization parameters are inherited from original linear node, but - # output reshape should use the linear node's output qparams for both input - # and output. - if "input_qparams" in output.meta: - output.meta["input_qparams"] = output.meta.get( - "output_qparams", None - ) node.replace_all_uses_with(output) graph_module.graph.erase_node(node) diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill_pass.py similarity index 97% rename from backends/arm/_passes/decompose_masked_fill.py rename to backends/arm/_passes/decompose_masked_fill_pass.py index 5a0f12348ec..09a3492a0c6 100644 --- a/backends/arm/_passes/decompose_masked_fill.py +++ b/backends/arm/_passes/decompose_masked_fill_pass.py @@ -34,7 +34,7 @@ def _get_decomposition(op) -> tuple: raise RuntimeError(f"Unable to get decomposition for op {op}") -class DecomposeMaskedFill(ArmPass): +class DecomposeMaskedFillPass(ArmPass): """ Masked fill takes in a boolean mask, a tensor and a scalar value. Fills the tensor with the scalar value according to the boolean mask. diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py similarity index 88% rename from backends/arm/_passes/decompose_maxpool2d_with_dilation.py rename to backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py index 9e98ad90aed..bf3f6afc418 100644 --- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation_pass.py @@ -19,7 +19,7 @@ ) -class DecomposeMaxPool2DPass(ArmPass): +class DecomposeMaxPool2dPass(ArmPass): """ Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. """ @@ -70,6 +70,12 @@ def call_operator(self, op, args, kwargs, meta): ph2 += extra_h * d_h pw2 += extra_w * d_w + meta_with_no_qparams = meta.copy() + meta_with_no_qparams.data["output_qparams"] = {} + meta_with_no_qparams.data["input_qparams"] = {} + meta_with_no_output_qparams = meta.copy() + meta_with_no_output_qparams.data["output_qparams"] = {} + # 1) Pad via EXIR edge pad (preserves dtype) pad_edge = exir_ops.edge.aten.constant_pad_nd.default pads = [pw, pw2, ph, ph2, 0, 0, 0, 0] @@ -77,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta): pad_edge, (x, pads, 0), {}, - meta, + meta_with_no_output_qparams, ) # 2) Space-to-batch: reshape and permute @@ -85,19 +91,19 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.view_copy.default, (x_pad, [N, C, H_pack, d_h, W_pack, d_w]), {}, - meta, + meta_with_no_qparams, ) x2 = super().call_operator( exir_ops.edge.aten.permute_copy.default, (x2, [3, 5, 0, 1, 2, 4]), {}, - meta, + meta_with_no_qparams, ) x2 = super().call_operator( exir_ops.edge.aten.view_copy.default, (x2, [N * d_h * d_w, C, H_pack, W_pack]), {}, - meta, + meta_with_no_qparams, ) # 3) Core pooling on packed tensor @@ -120,13 +126,13 @@ def call_operator(self, op, args, kwargs, meta): operator.getitem, (pool_out, 0), {}, - meta, + meta_with_no_qparams, ) indices_proxy = super().call_operator( operator.getitem, (pool_out, 1), {}, - meta, + meta_with_no_qparams, ) pooled_fake, _ = pool_out.data else: @@ -141,20 +147,20 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.view_copy.default, (pooled_proxy, [d_h, d_w, N, C_out, H_out, W_out]), {}, - meta, + meta_with_no_qparams, ) out = super().call_operator( exir_ops.edge.aten.permute_copy.default, (out, [2, 3, 4, 0, 5, 1]), {}, - meta, + meta_with_no_qparams, ) # now flatten back into (N, C, H_out*d_h, W_out*d_w) out = super().call_operator( exir_ops.edge.aten.view_copy.default, (out, [N, C_out, H_out * d_h, W_out * d_w]), {}, - meta, + meta_with_no_qparams, ) # 5) Final crop @@ -166,13 +172,13 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.slice_copy.Tensor, (out, 2, S_top, S_top + H), {}, - meta, + meta_with_no_qparams, ) out = super().call_operator( exir_ops.edge.aten.slice_copy.Tensor, (out, 3, S_left, S_left + W), {}, - meta, + meta_with_no_qparams, ) if is_with_indices: @@ -181,7 +187,7 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.view_copy.default, (indices_proxy, [d_h, d_w, N, C_out, H_out, W_out]), {}, - meta, + meta_with_no_qparams, ) idx = super().call_operator( exir_ops.edge.aten.permute_copy.default, @@ -193,19 +199,19 @@ def call_operator(self, op, args, kwargs, meta): exir_ops.edge.aten.view_copy.default, (idx, [N, C_out, H_out * d_h, W_out * d_w]), {}, - meta, + meta_with_no_qparams, ) idx = super().call_operator( exir_ops.edge.aten.slice_copy.Tensor, (idx, 2, S_top, S_top + H), {}, - meta, + meta_with_no_qparams, ) idx = super().call_operator( exir_ops.edge.aten.slice_copy.Tensor, (idx, 3, S_left, S_left + W), {}, - meta, + meta_with_no_qparams, ) return out, idx diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py index 1360fc44f98..9bff06b4dfe 100644 --- a/backends/arm/_passes/decompose_meandim_pass.py +++ b/backends/arm/_passes/decompose_meandim_pass.py @@ -11,7 +11,9 @@ from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.backend.utils import WhyNoPartitionReporter @@ -78,7 +80,7 @@ class DecomposeMeanDimPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, DecomposeSumPass, SizeAdjustInputPass, } diff --git a/backends/arm/_passes/decompose_remainder_pass.py b/backends/arm/_passes/decompose_remainder_pass.py index ac37eae86df..6c11a7b600e 100644 --- a/backends/arm/_passes/decompose_remainder_pass.py +++ b/backends/arm/_passes/decompose_remainder_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Set, Type +from typing import Dict, Set, Type import torch from executorch.backends.arm._passes import ArmPass @@ -17,46 +17,50 @@ Op = OpOverload | EdgeOpOverload - -def _get_remainder_decomposition_ops(op: Op) -> tuple[Op, Op, Op]: - """ - Returns the (div_mode_op, mul_op, sub_op) needed to lower the provided - remainder operator. The concrete ops depend on whether the remainder op is - the aten or edge variant. - """ - if op == exir_ops.edge.aten.remainder.Tensor: - return ( - exir_ops.edge.aten.div.Tensor_mode, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.sub.Tensor, - ) - if op == torch.ops.aten.remainder.Tensor: - return ( - torch.ops.aten.div.Tensor_mode, - torch.ops.aten.mul.Tensor, - torch.ops.aten.sub.Tensor, - ) - raise RuntimeError(f"Can't get remainder decomposition ops for op {op}") +_decomposition_ops: Dict[Op, tuple[Op, Op, Op]] = { + exir_ops.edge.aten.remainder.Scalar: ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.mul.Scalar, + exir_ops.edge.aten.sub.Tensor, + ), + torch.ops.aten.remainder.Tensor: ( + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.mul.Tensor, + torch.ops.aten.sub.Tensor, + ), + torch.ops.aten.remainder.Scalar: ( + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.mul.Scalar, + torch.ops.aten.sub.Tensor, + ), + exir_ops.edge.aten.remainder.Tensor: ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sub.Tensor, + ), +} class DecomposeRemainderPass(ArmPass): """ Decompose the remainder operation into primitive arithmetic: remainder(x, y) -> x - floor_div(x, y) * y - where floor_div(x, y) == div(x, y, rounding_mode=\"floor\"). + where floor_div(x, y) == div(x, y, rounding_mode="floor"). """ _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} def call_operator(self, op, args, kwargs, meta, updated=False): supported_ops = ( + exir_ops.edge.aten.remainder.Scalar, exir_ops.edge.aten.remainder.Tensor, + torch.ops.aten.remainder.Scalar, torch.ops.aten.remainder.Tensor, ) if op not in supported_ops: return super().call_operator(op, args, kwargs, meta, updated) - div_op, mul_op, sub_op = _get_remainder_decomposition_ops(op) + div_op, mul_op, sub_op = _decomposition_ops[op] x, y = args[0], args[1] floor_div = super().call_operator( diff --git a/backends/arm/_passes/decompose_sdpa_pass.py b/backends/arm/_passes/decompose_sdpa_pass.py new file mode 100644 index 00000000000..566b43d5aa3 --- /dev/null +++ b/backends/arm/_passes/decompose_sdpa_pass.py @@ -0,0 +1,16 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.transforms import decompose_sdpa +from executorch.exir.pass_base import ExportPass + + +class DecomposeScaledDotProductAttentionPass( + ArmPass, decompose_sdpa.DecomposeScaledDotProductAttention +): + _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index ba12f9d93d7..23b100ca41b 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -52,10 +52,18 @@ def call(self, graph_module: torch.fx.GraphModule): with graph_module.graph.inserting_before(node): slice_node = create_node( - graph_module.graph, slice_op, (input_node, dim, index, index + 1) + graph_module.graph, + slice_op, + (input_node, dim, index, index + 1), + from_node=node, + inherit_qparams=False, ) squeeze_node = create_node( - graph_module.graph, squeeze_op, (slice_node, [dim]), from_node=node + graph_module.graph, + squeeze_op, + (slice_node, [dim]), + from_node=node, + inherit_qparams=True, ) node.replace_all_uses_with(squeeze_node) diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index d96616a6373..0e63ef38669 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -77,7 +77,11 @@ def call_operator(self, op, args, kwargs, meta): for dim in dims: input_node = super().call_operator( - sum_op, (input_node, dim, True), kwargs, meta, updated=True + sum_op, + (input_node, dim, True), + kwargs, + meta, + updated=True, ) if not keepdims: diff --git a/backends/arm/_passes/decompose_var_pass.py b/backends/arm/_passes/decompose_var_pass.py index f5903d61135..bb2e2066a06 100644 --- a/backends/arm/_passes/decompose_var_pass.py +++ b/backends/arm/_passes/decompose_var_pass.py @@ -12,7 +12,9 @@ from executorch.backends.arm._passes.arm_pass_utils import get_node_arg from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -52,7 +54,7 @@ class DecomposeVarPass(ArmPass): """ _passes_required_after: Set[Type[ExportPass]] = { - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, DecomposeMeanDimPass, DecomposeSumPass, } diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 2a0e889f87c..8815a47b18c 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -9,6 +9,7 @@ from typing import cast, Optional, Set, Type +import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( get_param_tensor, @@ -152,6 +153,83 @@ def fold_and_annotate_arg( if len(n.users) == 0: graph_module.graph.erase_node(n) + def _handle_control_flow_node(self, node: Node, graph_module: GraphModule): + """Fold outmost quant nodes inside submodule. + placeholders => qs => dqs => ... => qs => dqs => output + becomes + placeholders => dqs => ... => qs => output, + With output_qparams meta in the placeholders, and input_qparams meta in the output node. + """ + match node.target: + case torch.ops.higher_order.cond: + submodule_nodes = cast(list[Node], node.args[1:3]) + args = cast(list[Node], node.args[-1]) + case torch.ops.higher_order.while_loop: + submodule_nodes = cast(list[Node], node.args[0:2]) + args = cast(list[Node], node.args[-2]) + case _: + raise ValueError(f"Unhandled target {node.target}") + submodules = ( + graph_module.get_submodule(str(submodule_node.target)) + for submodule_node in submodule_nodes + ) + for submodule in submodules: + submodule = cast(GraphModule, submodule) + output_node = submodule.graph.output_node() + output_node.meta["input_qparams"] = {} + nodes_to_remove = [] + arg_id = 0 + for submodule_node in submodule.graph.nodes: + # Remove initial q nodes and ending dq nodes in the module. + submodule_node = cast(Node, submodule_node) + if ( + submodule_node.target in Q_OPS + and list(submodule_node.all_input_nodes)[0].op == "placeholder" + ): + input_node = cast(Node, submodule_node.args[0]) + input_node.meta["val"] = submodule_node.meta["val"] + quant_args = QuantArgs.from_operator( + submodule_node.target, submodule_node.args + ) + input_node.meta["output_qparams"] = {0: quant_args} + + submodule_node.replace_all_uses_with(input_node) + nodes_to_remove.append(submodule_node) + if submodule_node.target in DQ_OPS: + has_non_output_user = False + for user in copy.copy(submodule_node.users): + if user.op != "output": + has_non_output_user = True + else: + input_node = cast(Node, submodule_node.args[0]) + submodule_node.replace_all_uses_with(input_node) + arg_index = cast(list[Node], output_node.args[0]).index( + input_node + ) + quant_args = QuantArgs.from_operator( + submodule_node.target, submodule_node.args + ) + output_node.meta["input_qparams"][arg_index] = quant_args + + # Remove dq node if it only has the output node as its user. + if not has_non_output_user: + nodes_to_remove.append(submodule_node) + # Placeholders without users won't be retraced with correct dtype, do it manually. + # Control flow node input is matched to placeholder nodes in the submodule by index. + # This means it will break if another pass inserts a placeholder before this pass. + if submodule_node.op == "placeholder": + if len(submodule_node.users) == 0: + submodule_node.meta["val"] = args[arg_id].meta["val"] + arg_id += 1 + if arg_id > len(args): + raise RuntimeError( + "Submodule had more placeholders than calling node had inputs." + " This is probably due to a placeholder being inserted in a pass." + ) + for node_to_remove in nodes_to_remove: + submodule.graph.erase_node(node_to_remove) + return + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 # Loop over the graph nodes and find any node in the 'targeted_ops' list. @@ -181,8 +259,8 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 n.meta["input_qparams"] = {} n.meta["output_qparams"] = {} for i, arg in enumerate(n.args): - if isinstance(arg, list): - self.fold_and_annotate_arg(graph_module, n, arg, i) + if isinstance(arg, (list, tuple)): + self.fold_and_annotate_arg(graph_module, n, arg, i) # type: ignore elif isinstance(arg, Node): self.fold_and_annotate_arg(graph_module, n, [arg], i) @@ -211,6 +289,12 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 output_dtype = output_qparams[0].dtype set_node_arg(n, "dtype", output_dtype) + if n.target in ( + torch.ops.higher_order.cond, + torch.ops.higher_order.while_loop, + ): + self._handle_control_flow_node(n, graph_module) + # retrace the graph to update the fake tensor types graph_module = super().call(graph_module).graph_module @@ -218,7 +302,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 return PassResult(graph_module, True) -class QuantizeOperatorArguments(ArmPass): +class QuantizeClampArgumentsPass(ArmPass): """ This pass makes sure that the arguments to clamp.default are quantized correctly. More specifically, this pass: diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batch_norm2d_pass.py similarity index 99% rename from backends/arm/_passes/fuse_batchnorm2d_pass.py rename to backends/arm/_passes/fuse_batch_norm2d_pass.py index 250cac230d8..d9ae706f503 100644 --- a/backends/arm/_passes/fuse_batchnorm2d_pass.py +++ b/backends/arm/_passes/fuse_batch_norm2d_pass.py @@ -26,7 +26,7 @@ from torch.nn.utils.fusion import fuse_conv_bn_weights -class FuseBatchnorm2DPass(ArmPass): +class FuseBatchNorm2dPass(ArmPass): """Fuses the pattern convolution -> batchnorm by updating the weights and bias of the convolution and removing the batchnorm. """ diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 2c8986114db..a574ef554ad 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -164,7 +164,7 @@ def call(self, graph_module): return PassResult(graph_module, True) -class ComputeConstantOpsAOT(ArmPass): +class ComputeConstantOpsAOTPass(ArmPass): """ Evaluates call_functions that produce constant tensor outputs and replaces them with placeholders. diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index f50216153a5..09e989cd3aa 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass +from executorch.backends.arm._passes.convert_to_clamp_pass import ConvertToClampPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, ) diff --git a/backends/arm/_passes/fuse_view_copy_transform_pass.py b/backends/arm/_passes/fuse_view_copy_transform_pass.py new file mode 100644 index 00000000000..cef3b408c24 --- /dev/null +++ b/backends/arm/_passes/fuse_view_copy_transform_pass.py @@ -0,0 +1,14 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform +from executorch.exir.pass_base import ExportPass + + +class FuseViewCopyTransformPass(ArmPass, FuseViewCopyTransform): + _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py index ef5aa9625c7..de80d61bfbe 100644 --- a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py +++ b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py @@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass): # Key: op overload; Value: zero-based indices of positional args that must be i64. I64_INPUT_ARG_POSITIONS = { torch.ops.aten.one_hot.default: (0,), + torch.ops.aten.index_copy_.default: (2,), + torch.ops.aten.index_copy.default: (2,), } def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule): diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index c1a66323f33..9e69a1e7e53 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -369,3 +369,216 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.recompile() return PassResult(graph_module, modified) + + +class InsertControlFlowRescalesPass(ArmPass): + """The quantization parameters for tensors going into and coming out of a submodule are not guaranteed to + match the quantization parameters for the corresponding tensors inside the submodule. For example, cond has + different annotation on input and output, while the entire graph inside the submodule could be using shared + annotation. This pass solves this by inserting rescales in the beginning and end of the submodule + that transform the tensor from one set of quantization parameters to another. + + The pass is run by the graph_module containing the control flow operator, but requires that the affected nodes + inside the submodule have been q-dq folded and have input/output_qparams meta. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def _get_input_nodes(self, graph_module: GraphModule): + return [node for node in graph_module.graph.nodes if node.op == "placeholder"] + + def _insert_rescale( + self, + in_qparams: QuantArgs, + out_qparams: QuantArgs, + from_node: Node, + graph_module: GraphModule, + ): + """Insert a rescale into the graph, inheriting meta from `from_node`. + The node is not connected to anything, that is up to the user.""" + + new_scales = [ + in_qparams.get_scale_per_tensor() / out_qparams.get_scale_per_tensor() + ] + + rescale_node = create_node( + graph_module.graph, + exir_ops.backend.tosa.RESCALE.default, + ( + None, + out_qparams.dtype, + new_scales, + in_qparams.get_zp_per_tensor(), # Old zero point + out_qparams.get_zp_per_tensor(), # New zero point + ), + from_node=from_node, + ) + return rescale_node + + def _rescale_submodule_inputs( + self, submodule: GraphModule, input_qparams_map: Dict[int, QuantArgs] + ) -> bool: + """Insert rescales at the inputs of `submodule` to match the qparams outside the submodule. + Matching the correct qparams gets a bit tricky: + Containing module: | submodule: + ops => cond | => placeholders => ... + + The dq->q qparam pair we want to convert to a rescale is: + (input qparams of op, output qparams of placeholder) + And the rescale is inserted after the placeholder. + + Args: + submodule: GraphModule: the GraphModule in which to rescale the inputs. + input_qparams_map: A map of input indexes mapping to QuantArgs. Not guaranteed to contain a mapping + for every submodule input. + Returns: + True if at least one rescale was inserted, False otherwise. + """ + + modified = False + input_nodes = self._get_input_nodes(submodule) + for qargs_index in input_qparams_map: + input_node = input_nodes[qargs_index] + if len(input_node.users) == 0: + continue + if len(out_qparams_map := input_node.meta.get("output_qparams", {})) != 1: + raise ValueError( + f"Expected submodule input {input_node} to have exactly one output qparam, got {out_qparams_map}" + ) + in_qparams = input_qparams_map[qargs_index] + out_qparams = cast(QuantArgs, out_qparams_map[0]) + + # Remove qparam meta to not confuse folding pass. + del input_node.meta["output_qparams"] + if in_qparams == out_qparams: + continue + with submodule.graph.inserting_after(input_node): + modified = True + rescale_node = self._insert_rescale( + in_qparams, out_qparams, input_node, submodule + ) + input_node.replace_all_uses_with(replace_with=rescale_node) + rescale_node.update_arg(0, input_node) + return modified + + def _rescale_submodule_outputs( + self, submodule: GraphModule, output_qparams_map: Dict[int, QuantArgs] + ) -> bool: + """Insert rescales at the outputs of `submodule` to match the qparams outside the submodule. + Matching the correct qparams gets a bit tricky: + Submodule: | Containing module: + output_nodes => output |=> getitems => ... + + The dq->q qparam pair we want to convert to a rescale is: + (input qparam of output_node, output qparam of getitem) + And the rescale is inserted between op and output. Note that the output qparam of op is called input_qargs, + since the it is the input to the dq-q pair. + + Args: + submodule: GraphModule: the GraphModule in which to rescale the outputs. + output_qparams_map: A map of output indexes mapping to QuantArgs. Not guaranteed to contain a mapping + for every submodule output. + Returns: + True if at least one rescale was inserted, False otherwise. + """ + + modified = False + output_node = submodule.graph.output_node() + output_args = list(cast(tuple[Node], output_node.args[0])) + input_qparams_map = cast( + dict[int, QuantArgs], output_node.meta["input_qparams"] + ) + for qargs_index in output_qparams_map: + output_arg_node = output_args[qargs_index] + in_qparams = input_qparams_map[qargs_index] + out_qparams = output_qparams_map[qargs_index] + if in_qparams == out_qparams: + continue + with submodule.graph.inserting_before(output_node): + modified = True + rescale_node = self._insert_rescale( + in_qparams, out_qparams, output_arg_node, submodule + ) + output_args[qargs_index] = rescale_node + rescale_node.update_arg(0, output_arg_node) + output_node.update_arg(0, tuple(output_args)) + # Remove qparam meta to not confuse folding pass. + del output_node.meta["input_qparams"] + return modified + + def _get_input_qparams_map(self, node: Node, idx: int): + input_qparams_meta = cast( + dict[int, QuantArgs], node.meta.get("input_qparams", None) + ) + if input_qparams_meta: + input_qparams = cast(QuantArgs, input_qparams_meta.get(idx, None)) + if not input_qparams: + raise ValueError( + f"Expected entry with key {idx} in input_qparams meta, got {input_qparams_meta}" + ) + num_inputs = len(cast(list, node.args[idx])) + + # Currently, infra only supports one set of qparams for a list of inputs + # Map all inputs to the same qparams. + input_qparams_map = {i: input_qparams for i in range(num_inputs)} + return input_qparams_map + return None + + def _get_output_qparams_map(self, node: Node): + output_qparams_map: dict[int, QuantArgs] = {} + for getitem_node in node.users: + idx = cast(int, getitem_node.args[1]) + qparam = getitem_node.meta.get("output_qparams", None) + if qparam: + output_qparams_map[idx] = cast(QuantArgs, qparam[0]) + return output_qparams_map + + def _rescale_cond_submodules(self, node: Node, graph_module: GraphModule) -> bool: + modified = False + if_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore + else_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[2].target)) # type: ignore + input_qparams_map = self._get_input_qparams_map(node, 3) + if input_qparams_map: + modified |= self._rescale_submodule_inputs(if_graph, input_qparams_map) + modified |= self._rescale_submodule_inputs(else_graph, input_qparams_map) + + output_qparams_map = self._get_output_qparams_map(node) + if output_qparams_map: + modified |= self._rescale_submodule_outputs(if_graph, output_qparams_map) + modified |= self._rescale_submodule_outputs(else_graph, output_qparams_map) + return modified + + def _rescale_while_submodules(self, node: Node, graph_module: GraphModule): + modified = False + cond_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[0].target)) # type: ignore + body_graph: GraphModule = cast(GraphModule, graph_module.get_submodule(node.args[1].target)) # type: ignore + + input_qparams_map = self._get_input_qparams_map(node, 2) + if input_qparams_map: + modified |= self._rescale_submodule_inputs(cond_graph, input_qparams_map) + modified |= self._rescale_submodule_inputs(body_graph, input_qparams_map) + + output_qparams_map = self._get_output_qparams_map(node) + if output_qparams_map: + modified |= self._rescale_submodule_outputs(body_graph, output_qparams_map) + return modified + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + + for node in list(graph_module.graph.nodes): + node = cast(Node, node) + if node.op != "call_function": + continue + + if node.target == torch.ops.higher_order.cond: + modified = self._rescale_cond_submodules(node, graph_module) + if node.target == torch.ops.higher_order.while_loop: + modified = self._rescale_while_submodules(node, graph_module) + + if modified: + # Retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index c09df48f7be..34ccdc82918 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -49,6 +49,7 @@ def __init__(self, exported_program: ExportedProgram) -> None: exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.bitwise_right_shift.Tensor, exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.eq.Tensor, @@ -57,6 +58,7 @@ def __init__(self, exported_program: ExportedProgram) -> None: exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.remainder.Tensor, exir_ops.edge.aten.where.self, exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_xor.Tensor, diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 9ff15e2850b..34634b99712 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -54,7 +54,10 @@ def call(self, graph_module: torch.fx.GraphModule): with graph.inserting_before(node): unsqueeze_before = create_node( - graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node + graph, + exir_ops.edge.aten.unsqueeze_copy.default, + from_node=node, + inherit_qparams=False, ) unsqueeze_before.args = ( input_node, # Input is node's original input @@ -68,6 +71,7 @@ def call(self, graph_module: torch.fx.GraphModule): graph, exir_ops.edge.aten.bmm.default, from_node=node, + inherit_qparams=True, ) bmm_node.args = node.args node.replace_all_uses_with(bmm_node) @@ -79,6 +83,7 @@ def call(self, graph_module: torch.fx.GraphModule): graph, exir_ops.edge.aten.squeeze_copy.dims, from_node=node, + inherit_qparams=False, ) squeeze_after.args = ( bmm_node, diff --git a/backends/arm/_passes/remove_getitem_pass.py b/backends/arm/_passes/remove_getitem_pass.py new file mode 100644 index 00000000000..3ce157d3fd8 --- /dev/null +++ b/backends/arm/_passes/remove_getitem_pass.py @@ -0,0 +1,14 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.transforms import remove_getitem_op +from executorch.exir.pass_base import ExportPass + + +class RemoveGetItemPass(ArmPass, remove_getitem_op.RemoveGetItemPass): + _passes_required_after: Set[Type[ExportPass]] = set() diff --git a/backends/arm/_passes/remove_graph_asserts_pass.py b/backends/arm/_passes/remove_graph_asserts_pass.py new file mode 100644 index 00000000000..a462c1182ee --- /dev/null +++ b/backends/arm/_passes/remove_graph_asserts_pass.py @@ -0,0 +1,18 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +from executorch.backends.arm._passes.arm_pass import ArmPass + +from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import ( + ConvertInt64ConstOpsToInt32Pass, +) +from executorch.exir.pass_base import ExportPass +from executorch.exir.passes import remove_graph_asserts_pass + + +class RemoveGraphAssertsPass(remove_graph_asserts_pass.RemoveGraphAssertsPass, ArmPass): + _passes_required_after: Set[Type[ExportPass]] = {ConvertInt64ConstOpsToInt32Pass} diff --git a/backends/arm/_passes/replace_inf_values_pass.py b/backends/arm/_passes/replace_inf_values_pass.py index 7a42d08dd61..d1f58fe148c 100644 --- a/backends/arm/_passes/replace_inf_values_pass.py +++ b/backends/arm/_passes/replace_inf_values_pass.py @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class ReplaceInfValues(ArmPass): +class ReplaceInfValuesPass(ArmPass): """ Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values. """ diff --git a/backends/arm/_passes/rewrite_conv2d_pass.py b/backends/arm/_passes/rewrite_conv2d_pass.py index 52feba5f8b9..8fa11a3f0cb 100644 --- a/backends/arm/_passes/rewrite_conv2d_pass.py +++ b/backends/arm/_passes/rewrite_conv2d_pass.py @@ -266,6 +266,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: op_target=target_op, args=conv2d_args, from_node=node, + inherit_qparams=True, ) bias_fake_tensor = get_first_fake_tensor(bias) if bias else None tosa_node_fake_tensor = target_op( diff --git a/backends/arm/_passes/rewrite_matmul.py b/backends/arm/_passes/rewrite_matmul.py index 410f0d62bff..298cfd17f0c 100644 --- a/backends/arm/_passes/rewrite_matmul.py +++ b/backends/arm/_passes/rewrite_matmul.py @@ -68,6 +68,7 @@ def call(self, graph_module): args=(x1, x2), kwargs={}, from_node=node, + inherit_qparams=True, ) node.replace_all_uses_with(tosa_matmul_node) graph_module.graph.erase_node(node) diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index e0ef1dbcf4a..cff241d33cf 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -11,6 +11,7 @@ create_node, get_first_fake_tensor, ) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.utils import get_resize_parameters from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -48,11 +49,14 @@ def call(self, graph_module): args=(x, output_size, align_corners, scale_factors), kwargs={"resize_mode": resize_mode}, from_node=node, + inherit_qparams=True, ) node.replace_all_uses_with(tosa_resize_node) graph_module.graph.erase_node(node) input_dtype = get_first_fake_tensor(x).dtype - if input_dtype == torch.int8 and resize_mode == "bilinear": + if ( + input_dtype == torch.int8 or input_dtype == torch.int16 + ) and resize_mode == "bilinear": input_size = get_first_fake_tensor(x).shape input_size_xy = input_size[2:] output_size = get_first_fake_tensor(node).shape @@ -71,6 +75,11 @@ def call(self, graph_module): exir_ops.backend.tosa.RESCALE.default, ) tosa_resize_node.replace_all_uses_with(rescale_node) + if input_dtype == torch.int16: + tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = ( + TosaSpecialDtype.INT48 + ) + rescale_node.args = ( tosa_resize_node, output_dtype, diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 23e6ec422aa..9460c8f199a 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -113,17 +113,17 @@ def is_valid_operator(node: torch.fx.Node) -> bool: dilation = node.args[4] if len(node.args) >= 5 else 1 ceil_mode = node.args[5] if len(node.args) >= 6 else False - # Dilation should be handled first by DecomposeMaxPool2DPass + # Dilation should be handled first by DecomposeMaxPool2dPass if isinstance(dilation, int): if dilation > 1: raise ValueError( - "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2DPass been run?" + "Expected max_pool2d with dilation = 1, has DecomposeMaxPool2dPass been run?" ) else: dilation = cast(list, dilation) if dilation[0] > 1 or dilation[1] > 1: raise ValueError( - "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2DPass been run?" + "Expected max_pool2d with dilation = [1, 1], has DecomposeMaxPool2dPass been run?" ) # If using ceil mode for rounding, the input does not need adjusting @@ -207,7 +207,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: with graph_module.graph.inserting_before(node): last_node = cast(torch.fx.Node, parent_node) for args in slice_args: - slice_node = create_node(graph, slice_op, (last_node,) + args) + slice_node = create_node( + graph, slice_op, (last_node,) + args, from_node=node + ) last_node = slice_node node.replace_input_with(cast(torch.fx.Node, parent_node), last_node) modified_graph = True diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 956eb77b62c..7e998e3a436 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -299,6 +299,8 @@ def remove_dim_order_kwargs( def call(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: + if "val" not in node.meta: + continue node_data = get_first_fake_tensor(node).data self.remove_dim_order_kwargs(graph_module, node) diff --git a/backends/arm/common/annotation_meta.py b/backends/arm/common/annotation_meta.py new file mode 100644 index 00000000000..12ef80ae70b --- /dev/null +++ b/backends/arm/common/annotation_meta.py @@ -0,0 +1,39 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping, Optional + + +@dataclass(frozen=True, init=False) +class ArmAnnotationInfo(dict): + """ + Dataclass wrapper that behaves like a dict so serialization can treat it as + a plain mapping, while still exposing a typed attribute for convenience. + """ + + quantized: bool + CUSTOM_META_KEY: str = "_arm_annotation_info" + + def __init__( + self, + value: Optional[Mapping[str, Any]] = None, + *, + quantized: Optional[bool] = None, + ) -> None: + if quantized is not None: + resolved = bool(quantized) + + elif isinstance(value, Mapping): + resolved = bool(value.get("quantized", False)) + + else: + raise TypeError( + "ArmAnnotationInfo expects a mapping with a 'quantized' entry or a keyword 'quantized'." + ) + dict.__init__(self, quantized=resolved) + object.__setattr__(self, "quantized", resolved) diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py index a33531e2411..1075594e901 100644 --- a/backends/arm/common/arm_compile_spec.py +++ b/backends/arm/common/arm_compile_spec.py @@ -35,6 +35,7 @@ class DebugMode(Enum): _OUTPUT_FORMAT_KEY = "output_format" _DEBUG_ARTIFACT_KEY = "debug_artifact_path" _DEBUG_MODE_KEY = "dump_debug_info" + _OUTPUT_REORDER_KEY = "ouput_reorder_workaround" def _set_compile_specs( self, @@ -42,12 +43,14 @@ def _set_compile_specs( compiler_flags: list[str], path_for_intermediates: str | None = None, tosa_debug_mode: DebugMode | None = None, + output_order_workaround: bool = True, ): """Set all values of dataclass directly.""" self.tosa_spec = tosa_spec self.compiler_flags = compiler_flags self.path_for_intermediates = path_for_intermediates self.tosa_debug_mode = tosa_debug_mode + self.output_order_workaround = output_order_workaround @classmethod def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 @@ -56,10 +59,15 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 compiler_flags: list[str] | None = None path_for_intermediates: str | None = None tosa_debug_mode: ArmCompileSpec.DebugMode | None = None + output_order_workaround: bool = True unknown_specs: dict[str, str] = {} for spec in compile_specs: key = spec.key - val = spec.value.decode() + val = ( + spec.value.decode() + if isinstance(spec.value, (bytes, bytearray)) + else spec.value + ) if key == ArmCompileSpec._TOSA_SPEC_KEY: if tosa_spec is not None: raise ValueError("More than one tosa_spec entry in compile spec.") @@ -88,6 +96,8 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 "More than one tosa_debug_mode entry in compile spec." ) tosa_debug_mode = ArmCompileSpec.DebugMode[val] + elif key == ArmCompileSpec._OUTPUT_REORDER_KEY: + output_order_workaround = val # type: ignore[assignment] else: unknown_specs[key] = val @@ -109,6 +119,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 compiler_flags=compiler_flags, path_for_intermediates=path_for_intermediates, tosa_debug_mode=tosa_debug_mode, + output_order_workaround=output_order_workaround, ) cls.from_list_hook(compile_spec, unknown_specs) compile_spec.validate() @@ -170,6 +181,14 @@ def to_list(self): ) ) + if not self.output_order_workaround: + compile_spec.append( + CompileSpec( + ArmCompileSpec._OUTPUT_REORDER_KEY, + self.output_order_workaround, + ) + ) + return compile_spec def get_intermediate_path(self) -> str | None: @@ -201,6 +220,13 @@ def dump_debug_info(self, debug_mode: DebugMode | None): self.tosa_debug_mode = debug_mode return self + def set_output_order_workaround(self, output_order_workaround: bool): + self.output_order_workaround = output_order_workaround + return self + + def get_output_order_workaround(self) -> bool: + return self.output_order_workaround + @classmethod @abstractmethod def get_output_format(cls) -> str: diff --git a/backends/arm/ethosu/backend.py b/backends/arm/ethosu/backend.py index c2feab6478b..bd6da08dc38 100644 --- a/backends/arm/ethosu/backend.py +++ b/backends/arm/ethosu/backend.py @@ -9,6 +9,7 @@ # backends. Converts via TOSA as an intermediate form supported by AoT and # JIT compiler flows. # +"""Ahead-of-time Arm Ethos-U backend built on the shared TOSA pipeline.""" import logging from typing import final, List @@ -27,19 +28,28 @@ @final class EthosUBackend(BackendDetails): - """ - BackendDetails subclass for delegation to Ethos-U. Deduce the TOSA lowering from - the compile spec list by filtering out the compile spec values that are of interest - for the TOSABackend. + """BackendDetails subclass for delegation to Ethos-U. + + Deduce the TOSA lowering from the compile spec list by filtering out the + compile spec values that are of interest for the TOSABackend. + """ @staticmethod def _compile_tosa_flatbuffer( tosa_flatbuffer: bytes, compile_spec: EthosUCompileSpec ) -> bytes: - """ - Static helper method to do the compilation of the TOSA flatbuffer - representation to a target specific binary stream. + """Compile a TOSA flatbuffer into a target-specific binary stream. + + Args: + tosa_flatbuffer (bytes): Serialized TOSA graph produced by + ``TOSABackend``. + compile_spec (EthosUCompileSpec): Compile specification providing + Vela flags and intermediate paths. + + Returns: + bytes: Target-specific binary stream produced by Vela. + """ compile_flags = compile_spec.compiler_flags @@ -63,7 +73,7 @@ def _compile_tosa_flatbuffer( binary = vela_compile( tosa_flatbuffer, compile_flags, - verbose=logger.getEffectiveLevel() == logging.INFO, + verbose=logger.getEffectiveLevel() <= logging.INFO, intermediate_path=compile_spec.get_intermediate_path(), ) return binary @@ -73,6 +83,17 @@ def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], ) -> PreprocessResult: + """Lower the exported program and compile it for an Ethos-U target. + + Args: + edge_program (ExportedProgram): Program to lower to Ethos-U. + compile_specs (List[CompileSpec]): Serialized Ethos-U compile specs + supplied by the frontend. + + Returns: + PreprocessResult: Result containing the compiled Ethos-U binary. + + """ logger.info(f"{EthosUBackend.__name__} preprocess") compile_spec = EthosUCompileSpec.from_list(compile_specs) diff --git a/backends/arm/ethosu/compile_spec.py b/backends/arm/ethosu/compile_spec.py index 9b6289156fa..b53034c365e 100644 --- a/backends/arm/ethosu/compile_spec.py +++ b/backends/arm/ethosu/compile_spec.py @@ -15,16 +15,7 @@ class EthosUCompileSpec(ArmCompileSpec): - """ - Compile spec for Ethos-U NPU. - - Args: - target: Ethos-U accelerator configuration, e.g. ethos-u55-128. - system_config: System configuration to select from the Vela configuration file. - memory_mode: Memory mode to select from the Vela configuration file. - extra_flags: Extra flags for the Vela compiler. - config_ini: Vela configuration file(s) in Python ConfigParser .ini file format. - """ + """Compile specification for Ethos-U NPU targets.""" _TARGET_KEY = "target" @@ -36,6 +27,21 @@ def __init__( extra_flags: list[str] | None = None, config_ini: str | None = "Arm/vela.ini", ): + """Normalise Ethos-U compile configuration and compiler flags. + + Args: + target (str): Ethos-U accelerator configuration (for example, + ``"ethos-u55-128"``). + system_config (str | None): System configuration name from the Vela + config file. Defaults based on ``target`` when omitted. + memory_mode (str | None): Memory mode selection from the Vela config + file. Defaults based on ``target`` when omitted. + extra_flags (list[str] | None): Additional command-line flags for + Vela. + config_ini (str | None): Path to a Vela .ini configuration file. + Defaults to ``"Arm/vela.ini"``. + + """ self.target = target # Set vela compiler flags @@ -78,16 +84,18 @@ def __init__( self.validate() def to_list(self): + """Return compile specs including the encoded Ethos-U target.""" compile_specs = super().to_list() compile_specs.append(CompileSpec(self._TARGET_KEY, self.target.encode())) return compile_specs @classmethod def from_list_hook(cls, compile_spec, specs: dict[str, str]): + """Restore target-specific metadata from serialized compile specs.""" compile_spec.target = specs.get(cls._TARGET_KEY, None) def validate(self): - """Throws an error if the compile spec is not valid.""" + """Validate the configuration against supported Ethos-U settings.""" if len(self.compiler_flags) == 0: raise ValueError( "compile_flags are required in the CompileSpec list for EthosUBackend" @@ -99,4 +107,5 @@ def validate(self): @classmethod def get_output_format(cls) -> str: + """Return the artifact format emitted by this compile spec.""" return "vela" diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index f3c50ee3719..01d936be7ce 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -6,6 +6,7 @@ from . import ( # noqa clone_dim_order_support, + control_flow_support, convolution_support, embedding_support, ethos_u55_support, diff --git a/backends/arm/operator_support/clone_dim_order_support.py b/backends/arm/operator_support/clone_dim_order_support.py index 1397b74bf38..ae6445c050c 100644 --- a/backends/arm/operator_support/clone_dim_order_support.py +++ b/backends/arm/operator_support/clone_dim_order_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for dim-order clone in TOSA. + +This module registers a support check for ``dim_order_ops._clone_dim_order`` +ensuring input/output dtypes match and the value types are FakeTensors. + +""" import logging @@ -19,6 +25,8 @@ @register_tosa_support_check class CloneSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``_clone_dim_order``.""" + targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default] tosa_specs = [ @@ -29,6 +37,12 @@ class CloneSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: + """Return True if the node is supported by TOSA. + + Verify the operator target, the number and types of inputs/outputs, and + check that input and output dtypes match. + + """ if node.target not in self.targets: self.reporter.report_reject(node, f"Target {node.target} is not supported.") return False diff --git a/backends/arm/operator_support/control_flow_support.py b/backends/arm/operator_support/control_flow_support.py new file mode 100644 index 00000000000..24fa34f3462 --- /dev/null +++ b/backends/arm/operator_support/control_flow_support.py @@ -0,0 +1,162 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import typing +from typing import cast + +import torch +import torch.fx as fx + +from executorch.backends.arm._passes.arm_pass_utils import is_submodule_node +from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.specification import Tosa_1_00 +from executorch.exir import ExportedProgram +from executorch.exir.backend.utils import WhyNoPartitionReporter + +from torch.fx.passes.operator_support import OperatorSupportBase + + +def _fully_partitioned(submodule: fx.GraphModule) -> bool: + partition_tag = None + for submodule_node in submodule.graph.nodes: + if submodule_node.op == "call_function": + # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported. + if ( + submodule_node.target in Q_OPS + and list(submodule_node.all_input_nodes)[0].op == "placeholder" + ): + continue + if ( + submodule_node.target in DQ_OPS + and list(submodule_node.users)[0].op == "output" + ): + continue + if "delegation_tag" not in submodule_node.meta: + return False + if partition_tag is None: + partition_tag = submodule_node.meta["delegation_tag"] + elif submodule_node.meta["delegation_tag"] != partition_tag: + return False + return True + + +def _submodules_fully_partitioned( + node: fx.Node, exported_program: ExportedProgram +) -> bool: + """Returns whether the submodule arguments to a cond node were fully partitioned. + Updates "val" meta of the submodules if they are. + """ + match node.target: + case torch.ops.higher_order.cond: + submodule_args = node.args[1:3] + case torch.ops.higher_order.while_loop: + submodule_args = node.args[0:2] + case _: + raise ValueError(f"Unexpected target: {node.target}") + cond_submodules = ( + ( + exported_program.graph_module.get_submodule( + str(cast(torch.fx.Node, submodule_node).target) + ), + cast(torch.fx.Node, submodule_node), + ) + for submodule_node in submodule_args + ) + for submodule, submodule_node in cond_submodules: + submodule = cast(torch.fx.GraphModule, submodule) + + if _fully_partitioned(submodule): + submodule_node.meta["val"] = submodule.graph.output_node().meta["val"] + else: + return False + return True + + +def _tosa_spec_supports_cf(tosa_spec: TosaSpecification) -> bool: + if not isinstance(tosa_spec, Tosa_1_00): + return False + return tosa_spec.support_extension("cf") + + +class ControlFlowSubmoduleSupported(OperatorSupportBase): + """Check whether control flow submodule args should be partitioned. + Applies control-flow extension constraints before allowing delegation.""" + + def __init__( + self, + exported_program: ExportedProgram, + tosa_spec: TosaSpecification, + reporter: WhyNoPartitionReporter, + ): + self.exported_program = exported_program + self.reporter = reporter + self.tosa_spec = tosa_spec + super().__init__() + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if is_submodule_node(node): + if not _tosa_spec_supports_cf(self.tosa_spec): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + for user in node.users: + if user.target not in ControlFlowOpSupported._targeted_ops: + self.reporter.report_reject( + node, f"Submodule had unsupported user {user}" + ) + return False + if not _submodules_fully_partitioned(user, self.exported_program): + self.reporter.report_reject( + node, "One submodule was not fully partitioned" + ) + return False + return True + return False + + +class ControlFlowOpSupported(OperatorSupportBase): + """Check whether control flow ops should be partitioned. + Applies control-flow extension constraints before allowing delegation.""" + + _targeted_ops = { + torch.ops.higher_order.cond, + torch.ops.higher_order.while_loop, + } + + def __init__( + self, + exported_program: ExportedProgram, + tosa_spec: TosaSpecification, + reporter: WhyNoPartitionReporter, + ): + self.exported_program = exported_program + self.reporter = reporter + self.tosa_spec = tosa_spec + super().__init__() + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if node.target in self._targeted_ops: + if not _tosa_spec_supports_cf(self.tosa_spec): + self.reporter.report_reject( + node, + f"TOSA spec {self.tosa_spec} does not support control flow extension.", + ) + return False + + if not _submodules_fully_partitioned(node, self.exported_program): + self.reporter.report_reject( + node, "Submodule was not fully partitioned." + ) + return False + return True + + return False diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index f6fdada7d52..6c171e101aa 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -18,6 +18,9 @@ import torch.fx as fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.convert_permute_singleton_to_view_pass import ( + is_singleton_permutation, +) from executorch.backends.arm._passes.insert_table_ops import TableOps from executorch.backends.arm.operators.op_permute import transform_permutation_vector from executorch.backends.arm.tosa.utils import tosa_shape @@ -430,10 +433,17 @@ def _permute_constraint_i8_i16( ) -> bool: """Return True if permutation meets i8/i16 constraints.""" N, H, W, C = nhwc_shape + + if is_singleton_permutation(nhwc_shape, permutation): + return True + match permutation: case (0, 1, 2, 3): # NHWC -> NHWC return True - case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH + case ( + (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2) | (0, 2, 3, 1) | (0, 3, 2, 1) + ): + # NHWC -> NWHC, NHCW, NCWH, NCHW, NCHW -> NHWC return N * H <= 65536 and W <= 65536 and C <= 65536 case _: return self.axes_product(nhwc_shape) <= 65536 diff --git a/backends/arm/operator_support/minmax_support.py b/backends/arm/operator_support/minmax_support.py index 68433819f4b..8ba5d9335dc 100644 --- a/backends/arm/operator_support/minmax_support.py +++ b/backends/arm/operator_support/minmax_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for min/max along a dimension in TOSA. + +Provide support checks ensuring that argmax/argmin indices are not consumed, +restricting to float profiles until index quantization is supported. + +""" import torch.fx as fx from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -14,6 +20,8 @@ @register_tosa_support_check class MinMaxSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.max.dim`` and ``aten.min.dim``.""" + targets = [ exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim, @@ -24,7 +32,16 @@ class MinMaxSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. + + Allow max/min when the argmax/argmin output is unused or dropped (i.e., + only the value is consumed). Disallow cases where arg indices are + further used. + + """ if node.target in [exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim]: no_argmax = len(node.users) == 1 no_argmax_users = (len(node.users) == 2) and ( diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index 76d1ba7bf36..02e9e0db90e 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -2,7 +2,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.sum.dim_IntList`` in TOSA. +Provide shape constraints for U55 subsets; otherwise allow reductions. + +""" from typing import cast import torch.fx as fx @@ -16,6 +20,8 @@ @register_tosa_support_check class SumSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for sum over dimensions.""" + targets = [exir_ops.edge.aten.sum.dim_IntList] tosa_specs = [ @@ -23,7 +29,16 @@ class SumSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. + + On U55 subsets, enforce bounds on the reduced dimension and the products + of sizes before/after the reduction axis. On other targets, accept the + operation unconditionally. + + """ if not tosa_spec.is_U55_subset: return True diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py index 82c4387fc85..7670edec0a9 100644 --- a/backends/arm/operator_support/right_shift_support.py +++ b/backends/arm/operator_support/right_shift_support.py @@ -48,5 +48,5 @@ def is_node_tosa_supported( """ # TODO MLETORCH-525 Remove warning if tosa_spec.is_U55_subset: - logging.warning(f"{node.target} may introduce one-off errors.") + logger.warning(f"{node.target} may introduce one-off errors.") return True diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py index 14ca505635c..178f8dddb18 100644 --- a/backends/arm/operator_support/slice_copy_support.py +++ b/backends/arm/operator_support/slice_copy_support.py @@ -2,7 +2,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.slice_copy`` in TOSA. +Support slicing with unit step only; emit a warning and reject otherwise. + +""" import logging @@ -19,6 +23,8 @@ @register_tosa_support_check class SliceCopySupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.slice_copy``.""" + targets = [exir_ops.edge.aten.slice_copy.Tensor] tosa_specs = [ @@ -26,12 +32,20 @@ class SliceCopySupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc] + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: # type: ignore[override, misc] + """Return True if the node is supported by TOSA. + + Accept slice_copy when the step is 1 (or unspecified). Warn and reject + non-unit step sizes. + + """ if tosa_spec not in self.tosa_specs: return False args = node.args if len(args) == 5 and (step := args[4]) != 1: - logging.warning(f"{node.target} with step size of {step} not supported.") + logger.warning(f"{node.target} with step size of {step} not supported.") return False return True diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index 181796b97fe..48f0c4d8604 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -52,7 +52,6 @@ class ToCopySupported(SupportedTOSAOperatorCheck): @staticmethod def _merge_supported_types( - # pyre-ignore[11] dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict, ) -> SupportedTypeDict: diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 3a1d11eab8c..7c0d2e9907c 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide operator-support checks and registries for TOSA delegation. + +Define a base check class, a registry/dispatcher, and several generic checks +used by the TOSA partitioner to decide if FX nodes are eligible for delegation. + +""" import itertools @@ -12,13 +18,23 @@ import torch import torch.fx as fx -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor -from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT +from executorch.backends.arm._passes.arm_pass_utils import ( + get_first_fake_tensor, + is_submodule_node, +) +from executorch.backends.arm._passes.fuse_constant_ops_pass import ( + ComputeConstantOpsAOTPass, +) from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( FuseQuantizedActivationPass, ) from executorch.backends.arm._passes.insert_table_ops import TableOps +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS +from executorch.backends.arm.operator_support.control_flow_support import ( + ControlFlowOpSupported, + ControlFlowSubmoduleSupported, +) from executorch.backends.arm.operator_support.ethos_u55_support import ( EthosU55CastCheck, EthosU55DtypeSupport, @@ -30,7 +46,10 @@ TOSA_PRO_FP_SupportList, TOSA_PRO_INT_SupportList, ) -from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.specification import ( + TosaSpecification, + TosaSpecMapping, +) from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -42,15 +61,31 @@ class SupportedTOSAOperatorCheck(OperatorSupportBase): - """ - Supported OP for TOSA lowering + """Provide a base operator-support check for TOSA lowering. + + Subclasses should implement :py:meth:`is_node_tosa_supported` and declare + the class attributes below to indicate what they support. + + Attributes: + targets (list[OpOverload]): Operator overloads supported by this + check. + tosa_specs (list[TosaSpecification]): TOSA specs where the check is + applicable. + """ def __init__(self, tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter): + """Initialize the check with a TOSA spec and reporter. + + Args: + tosa_spec (TosaSpecification): Active TOSA specification. + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ self.tosa_spec = tosa_spec self.reporter = reporter - # Should be populated by subclass implementation + # Class attributes populated by subclasses tosa_specs: list[TosaSpecification] = [] targets: list[str] = [] @@ -58,6 +93,17 @@ def __init__(self, tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporte def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if the node matches targets and subclass-specific checks. + + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported program + modules. + node (fx.Node): Node to evaluate. + + Returns: + bool: True if both the target and TOSA-specific checks pass. + + """ if node.target not in self.targets: return False return self.is_node_tosa_supported(node, self.tosa_spec) @@ -65,39 +111,59 @@ def is_node_supported( def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: - """ - Checks if the fx.Node node is lowerable using the TOSA specification defined by tosa_spec. + """Check if the node is lowerable under the given TOSA spec. + + Args: + node (fx.Node): FX node to check. + tosa_spec (TosaSpecification): Active TOSA specification. + + Returns: + bool: True if supported; otherwise, False. + """ raise NotImplementedError("SupportedTOSAOperatorCheck must be extended.") # container for all SupportedTosaOperatorCheck classes -_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = { - TosaSpecification.create_from_string("TOSA-1.0+INT"): [], - TosaSpecification.create_from_string("TOSA-1.0+FP"): [], -} +_tosa_spec_support: TosaSpecMapping[Type[SupportedTOSAOperatorCheck]] = ( + TosaSpecMapping() +) def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]): - """ - Decorator to mark a subclass implmentation of SupportedTosaOperatorCheck - to be registered for checking if a torch.fx.Node is lowerable given - a TOSA specification. + """Register an operator-support checker for one or more TOSA specs. + + Decorate subclasses of :py:class:`SupportedTOSAOperatorCheck` so they are + picked up by the factory and partitioner for the specs declared in their + ``tosa_specs`` class attribute. + + Args: + checker (Type[SupportedTOSAOperatorCheck]): Checker class to register. + """ for tosa_spec in checker.tosa_specs: - _tosa_spec_support[tosa_spec].append(checker) + _tosa_spec_support.add(tosa_spec, checker) return checker def get_registered_tosa_support_checks( tosa_spec: TosaSpecification, ) -> list[Type[SupportedTOSAOperatorCheck]]: - if tosa_spec not in _tosa_spec_support: + """Get all registered operator-support checkers for a given spec. + + Args: + tosa_spec (TosaSpecification): TOSA spec to query. + + Returns: + list[Type[SupportedTOSAOperatorCheck]]: Registered checker classes. + + """ + checks = _tosa_spec_support.get(tosa_spec) + if not checks: raise RuntimeError( - f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}" + f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support._mapping.keys())}" ) - - return _tosa_spec_support[tosa_spec] + return checks def tosa_support_factory( @@ -106,11 +172,27 @@ def tosa_support_factory( reporter: WhyNoPartitionReporter, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> OperatorSupportBase: - """Generates an OperatorSupport class depending on the given `tosa_spec`. - Additional checks can be supplied to avoid partitioning additional nodes. + """Create an OperatorSupport composite for a TOSA spec. + + Combine profile-specific positive checks, registered operator checks, and + negative checks into a single :py:class:`OperatorSupportBase` chain. + + Args: + tosa_spec (TosaSpecification): Active TOSA specification. + exported_program (ExportedProgram): Program context for checks. + reporter (WhyNoPartitionReporter): Reporter for rejections. + additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra + negative checks to apply. + + Returns: + OperatorSupportBase: Composite checker for the given spec. + """ # Postive checks: Add nodes to partitioning - positive_checks: list[OperatorSupportBase] = [] + positive_checks: list[OperatorSupportBase] = [ + ControlFlowSubmoduleSupported(exported_program, tosa_spec, reporter), + ControlFlowOpSupported(exported_program, tosa_spec, reporter), + ] if tosa_spec.support_integer(): positive_checks.append(TOSAProINTSupportList()) @@ -134,6 +216,7 @@ def tosa_support_factory( ] if not tosa_spec.support_float(): + negative_checks.append(CheckArmQuantized(reporter)) negative_checks.append(CheckProperQuantization(reporter)) if tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) @@ -152,37 +235,116 @@ def tosa_support_factory( class TOSAProINTSupportList(OperatorSupportBase): - """ - TOSA_PRO_INT_SupportList: - Ops supported in INT profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOps. - Note that ops supported via pre-quantization decompositions are not included here. + """Provide the INT profile support list for TOSA. + + TOSA_PRO_INT_SupportList enumerates ops supported in the INT profile via + native TOSA ops, decompositions, pre-compute steps, or TableOps. + + Note: + Ops supported via pre-quantization decompositions are not included + here. + """ def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - + """Return True if the node is in the INT profile support list.""" return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList class TOSAProFPSupportList(OperatorSupportBase): + """Provide the FP profile support list for TOSA. + + Includes ops supported natively, via decomposition/transformation, and pre- + compute. + + """ + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + """Return True if the node is in the FP profile support list.""" + return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList + + +class CheckArmQuantized(OperatorSupportBase): """ - TOSA_PRO_FP_SupportList: - Ops supported in FP profile via native TOSA ops, decomposition/transformation, pre-compute + Check if the node was marked as quantized in the Arm backend. + This is used to ensure that nodes that were quantized in the Arm backend + are only partitioned if they are supported by the TOSA backend. """ + def __init__(self, reporter: WhyNoPartitionReporter): + self.reporter = reporter + + def _is_quantized(self, node: torch.fx.Node) -> bool: + """Checks if the node is quantized. + + A node is considered quantized if at least one criteria is met: + - Its dtype is not floating point or complex => integer + - It is one of the special cases where the node has been created in to_edge, e.g. + .Scalar operations that have been promoted .Tensor operations + where the scalar is replaced by a full op. + - It has been marked as quantized in the ArmAnnotationInfo custom meta. + + Args: + node (torch.fx.Node): The FX node to check. + + Returns: + bool: True if the node is quantized, False otherwise. + """ + node_dtype = get_first_fake_tensor(node).dtype + if not node_dtype.is_complex and not node_dtype.is_floating_point: + return True + if node.target in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOTPass.targeted_ops, + ): + # Special cases where nodes have been created in to_edge, e.g. + # .Scalar operations that have been promoted .Tensor operations + # where the scalar is replaced by a full op. + if all(user.target in Q_OPS for user in node.users): + return True + for user in node.users: + if ( + user.target + == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ): + dim_order_dtype = get_first_fake_tensor(user).dtype + if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point: + return False + else: + return False + return True + return ( + ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) + and ArmAnnotationInfo( + node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY] + ).quantized + ) + def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList + if node.target in (*DQ_OPS, *Q_OPS): + return True + + if not self._is_quantized(node): + self.reporter.report_reject( + node, "Node was not marked as quantized in the Arm backend." + ) + return False + return True class CheckProperQuantization(OperatorSupportBase): - """ - For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize - and dequantize nodes surrounds the node. This is neccessary for table operators and operators that need to rescale - activations. + """Ensure targeted nodes are properly quantized. + + Verify that a pair of quantize/dequantize nodes surrounds targeted ops so + rescaling and table operators behave correctly. + """ targeted_ops = ( @@ -208,13 +370,28 @@ class CheckProperQuantization(OperatorSupportBase): ) def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter.""" self.reporter = reporter def _is_matmul_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ): - """ - Find the matmul source partition containing this node and check that all its inputs and outputs are quantized. + """Check quantization for decomposed matmul partitions. + + Handles an edge case where the quantized pipeline + `dq -> torch.matmul/operator.matmul -> q` decomposes into + `dq -> expand -> view -> aten.mm -> view -> q`. + + Args: + submodules (Mapping[str, torch.nn.Module]): Map of child modules to + inspect for matmul partitions. + node (fx.Node): Node that should belong to a quantized matmul + partition. + + Returns: + bool: True if the matched partition uses quantized inputs and + outputs. + """ for graph_module in submodules.values(): graph_module = typing.cast(fx.GraphModule, graph_module) @@ -263,6 +440,12 @@ def _is_matmul_node_supported( def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if the node passes constant-cast and multi-output checks. + + Ensures decomposition-specific matmul partitions keep quantized inputs + and outputs. + + """ output_quantized = False input_quantized = False if node.target not in self.targeted_ops: @@ -314,21 +497,22 @@ def is_node_supported( class CheckInt64InputsAndOutputs(OperatorSupportBase): - """TOSA does not support int64 tensors so in general, ops with int64 inputs or outputs should not be partitioned. - There are however some exceptions: - - Nodes with int64 output can be partitioned if they are constant, within int32, - and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT. - - Nodes with int64 output can be partitioned if all users are getitem with non-int64 output. - In this case, there are multiple outputs and the int64 ones are not used. - - Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant - ops fulfilling the criteria above. - Note that we don't check placeholders here, they are partitioned based on whether their users are partitioned - or not. + """Reject general int64 tensors while allowing safe exceptions. + + Exceptions are: + - Nodes with contant int64 output within int32 range that are cast away + from int64 by all users. + - Int64 output where all users are getitem nodes with non-int64 outputs. + In this case there are multiple outputs and the int64 output is unused. + - Nodes where all inputs are int64 constant placeholders or constant ops + that fulfill the above exceptions. + """ def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): + """Initialize the check with program context and reporter.""" self.input_names = [ spec.arg.name for spec in exported_program.graph_signature.input_specs @@ -350,7 +534,9 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - + """Return True when int64 use is absent or safe per exceptions.""" + if is_submodule_node(node): + return True vals = node.meta["val"] tensor_list = vals if isinstance(vals, (list, tuple)) else [vals] @@ -364,7 +550,7 @@ def is_node_supported( for output_node in node.users ) if ( - node.target in ComputeConstantOpsAOT.targeted_ops + node.target in ComputeConstantOpsAOTPass.targeted_ops and users_output_non_int64 ): if not self.inside_int32_bounds(node): @@ -390,7 +576,11 @@ def is_node_supported( # Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned. # If it is not partitioned, the partition will get an int64 input and fail. - for input_node in node.all_input_nodes: + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor_in = get_first_fake_tensor(input_node) if tensor_in.dtype != torch.int64: continue @@ -402,7 +592,7 @@ def is_node_supported( continue # Constant operator if input_node.op == "call_function": - if input_node.target in ComputeConstantOpsAOT.targeted_ops: + if input_node.target in ComputeConstantOpsAOTPass.targeted_ops: # This is not perfect since the input_node can still be rejected by other checks but # this should cover the majority of cases. if self.is_node_supported({}, input_node): @@ -416,18 +606,30 @@ def is_node_supported( class CheckFloat64Inputs(OperatorSupportBase): + """Reject nodes with float64 inputs. + + Useful as a negative check for specs that do not allow float64. + + """ def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): + """Initialize the check with program context and reporter.""" self.reporter = reporter super().__init__() def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - - for input_node in node.all_input_nodes: + """Return True if no float64 inputs are present.""" + if is_submodule_node(node): + return True + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.float64: self.reporter.report_reject( @@ -439,9 +641,10 @@ def is_node_supported( class RankCheck(OperatorSupportBase): - """Makes sure that nodes with input or output tensors with rank > max_rank are not partitioned""" + """Reject nodes with rank greater than ``max_rank``.""" def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int): + """Initialize the check with a reporter and maximum rank.""" self.reporter = reporter self.max_rank = max_rank super().__init__() @@ -449,7 +652,14 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - input_nodes = node.all_input_nodes + """Return True if input/output tensor ranks are within the limit.""" + if is_submodule_node(node): + return True + input_nodes = ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ) # check if any input node has an unsupported rank for input_node in input_nodes: input_node_shape = get_first_fake_tensor(input_node).shape diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS index afe1c4dd22c..38eb9e7cad9 100644 --- a/backends/arm/operators/TARGETS +++ b/backends/arm/operators/TARGETS @@ -24,7 +24,6 @@ runtime.python_library( ":node_visitor", ":operator_validation_utils", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/_passes:passes", "//executorch/exir:lib", diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a180d0a6e86..b987e99cf4f 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Import and register Arm TOSA operator visitors. + +Importing this package loads all visitor modules so their classes can be +registered via decorators and discovered at runtime. + +""" from . import ( # noqa @@ -16,6 +22,7 @@ op_cat, op_ceil, op_clamp, + op_cond_if, op_constant_pad_nd, op_cos, op_eq, @@ -57,6 +64,7 @@ op_tosa_transpose, op_view, op_where, + op_while, ops_binary, ops_identity, ) diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 682c849fe80..c54ae67e541 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -3,8 +3,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide utilities to register and apply TOSA node visitors. + +Use this module to construct and serialize TOSA operators from FX nodes. +- Define the NodeVisitor base class and registry +- Register concrete visitors per TOSA specification + +""" import json + +import logging from typing import Any, Dict, List, Optional import torch @@ -13,13 +22,23 @@ from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.debug.schema import DebugHook from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.backends.arm.tosa.specification import ( + TosaSpecification, + TosaSpecMapping, +) from torch.export import ExportedProgram +logger = logging.getLogger(__name__) + class NodeVisitor: - """ - Node Visitor pattern for lowering edge IR to TOSA + """Provide a visitor pattern to lower edge IR to TOSA. + + Attributes: + _exported_program (torch.export.ExportedProgram): Source program being lowered. + tosa_spec (TosaSpecification): Active TOSA specification for lowering. + debug_hook (Optional[DebugHook]): Optional hook for debug metadata. + """ # Add the currently supported node_visitor specs as default. @@ -51,6 +70,23 @@ def _serialize_operator( outputs: List[str], attributes: Optional[Any] = None, ) -> None: + """Serialize a TOSA operator into the graph. + + When a ``DebugHook`` is active, attach location metadata (in JSON) to + the operator for traceability. + + Args: + node (torch.fx.Node): Source FX node being lowered. + tosa_graph: Target TOSA serializer/graph object. + tosa_op: TOSA operator enum value to emit. + inputs (List[str]): Names of input tensors. + outputs (List[str]): Names of output tensors. + attributes (Optional[Any]): Optional TOSA attribute object. + + Returns: + None: Mutates ``tosa_graph`` in place. + + """ op_location = ts.TosaOpLocation() if self.debug_hook: debug_info = self.debug_hook.add( @@ -77,25 +113,50 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + """Define a TOSA operator node. + + Args: + node (torch.fx.Node): FX node being lowered. + tosa_graph (serializer.tosa_serializer.TosaSerializer): Target TOSA graph. + inputs (List[TosaArg]): Input tensor arguments. + output (TosaArg): Output tensor descriptor. + + Returns: + None: Mutates ``tosa_graph`` in place. + + Raises: + ValueError: If input count or dtypes are invalid. + + """ raise NotImplementedError("NodeVisitor must be extended.") # container for all node visitors -_node_visitor_dicts: Dict[TosaSpecification, Dict] = { - TosaSpecification.create_from_string("TOSA-1.0+INT"): {}, - TosaSpecification.create_from_string("TOSA-1.0+FP"): {}, -} +_node_visitor_tuples: TosaSpecMapping[tuple] = TosaSpecMapping() def register_node_visitor(visitor): + """Register a concrete ``NodeVisitor`` class for its TOSA specs.""" for tosa_spec in visitor.tosa_specs: - _node_visitor_dicts[tosa_spec][visitor.target] = visitor + # Try to get the tuple to make sure it doesn't exist + visitor_tuple = (visitor.target, visitor) + try: + tuples = _node_visitor_tuples.get(tosa_spec) + except KeyError: + tuples = [] + + if visitor_tuple in tuples: + raise RuntimeError( + f"Visitor for target {visitor.target} already registered for TOSA spec {tosa_spec}" + ) + _node_visitor_tuples.add(tosa_spec, visitor_tuple) return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: - node_visitors = {} - tosa_spec = None + """Return a mapping from target names to visitor instances for a spec.""" + node_visitors: Dict[str, NodeVisitor] = {} + tosa_spec: TosaSpecification | None = None for arg in args: if isinstance(arg, TosaSpecification): tosa_spec = arg @@ -104,7 +165,13 @@ def get_node_visitors(*args) -> Dict[str, NodeVisitor]: if tosa_spec is None: raise RuntimeError("No TOSA specification supplied.") - for target, visitor in _node_visitor_dicts[tosa_spec].items(): + # Use the mapping to get the dict for this spec (handles combined specs) + for node_visitor_tuple in _node_visitor_tuples.get(tosa_spec): + target, visitor = node_visitor_tuple + if target in node_visitors and node_visitors[target].__class__ != visitor: + logger.warning( + f"Target {target} already has visitor class {node_visitors[target].__class__.__name__} registered, overwriting with class: {visitor.__name__}" + ) node_visitors[target] = visitor(*args) return node_visitors diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index 3cbdd91d2e4..2a850c0cf52 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -46,7 +46,7 @@ def define_node( ) # process the negative index keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) if not keep_dim: - raise ValueError("This case should be handled by ConvertAnyDimDimsPass") + raise ValueError("This case should be handled by DecomposeAnyPass") attr = ts.TosaSerializerAttribute() attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim)) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 29d5c3bf635..693f3f1155a 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -118,11 +118,11 @@ def define_node( validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], output.tosa_spec, ) - if inputs[0].dtype == ts.DType.INT8: + if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: accumulator_type = ts.DType.INT32 input_qargs = get_input_qparams(node) input_zp = input_qargs[0].get_zp_per_tensor() diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index c98da6aae8f..badf76c9384 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -74,6 +74,8 @@ def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes: return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist() elif dtype == torch.int8: return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist() + elif dtype == torch.int16: + return np.frombuffer(np.int16(value).tobytes(), dtype=np.uint8).tolist() else: raise ValueError(f"Unsupported dtype for to_bytes: {dtype}") @@ -89,12 +91,12 @@ def define_node( validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32], output.tosa_spec, ) node_input_dtype = node.meta["val"].dtype - # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments + # NOTE: Quantization of the min/max arguments is handled by QuantizeClampArgumentsPass min_val, max_val = self._get_min_max_arguments(node, node_input_dtype) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_cond_if.py b/backends/arm/operators/op_cond_if.py new file mode 100644 index 00000000000..4cf5120de31 --- /dev/null +++ b/backends/arm/operators/op_cond_if.py @@ -0,0 +1,56 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import Any, cast, List + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( # type: ignore + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_cf_extension, + validate_num_inputs, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore +from torch.fx import Node + + +@register_node_visitor +class CondVisitor(NodeVisitor): + target = "cond" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + validate_num_inputs(self.target, inputs, 4) + validate_valid_dtype(self.target, [inputs[0]], ts.DType.BOOL, self.tosa_spec) + validate_cf_extension(self.target, self.tosa_spec) + + attr = ts.TosaSerializerAttribute() + if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3]) + attr.CondIfAttribute(if_graph, else_graph) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.COND_IF, + [ + inputs[0].name, + *(subgraph_input.name for subgraph_input in inputs[-1].special), + ], + output.multiple_output_names, + attr, + ) diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 3bda87af5ed..47d11fb5627 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -50,6 +50,7 @@ def define_node( [inputs[0], output], [ ts.DType.INT8, + ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL, @@ -62,6 +63,11 @@ def define_node( qargs = input_qparams[0] pad_const_val = qargs.quantize_value(inputs[2].number).item() pad_const_dtype = ts.DType.INT8 + elif inputs[0].dtype == ts.DType.INT16: + input_qparams = get_input_qparams(node) + qargs = input_qparams[0] + pad_const_val = qargs.quantize_value(inputs[2].number).item() + pad_const_dtype = ts.DType.INT16 else: pad_const_val = inputs[2].number pad_const_dtype = inputs[0].dtype diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 8fb789a9d01..7cfd497b1fe 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 5994cbc9c0f..5d6eeb75275 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 859e5c236d7..92879d549b1 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index db2488fa163..5b73b5e91ae 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -6,7 +6,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils # noqa: F401 import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 760e744923c..710b5f8e1d8 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -24,7 +24,6 @@ from torch.fx import Node -@register_node_visitor class CommonIndexTensorVisitor(NodeVisitor): target = "aten.index.Tensor" @@ -165,14 +164,14 @@ def define_node( # channels and thus the stride-shift. data = np.full(index_shape, int(values_strides[i] / C)) mul_const = tosa_graph.addConst(index_shape, index_dtype, data) - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_{i}_shift") attr = ts.TosaSerializerAttribute() attr.MulAttribute() self._serialize_operator( node, tosa_graph, ts.Op.MUL, - [index_name, mul_const.name, f"{node.name}_{i}_shift"], + [index_name, mul_const.name, f"{output.name}_{i}_shift"], [stride_shifted_indices.name], attr, ) @@ -186,7 +185,7 @@ def define_node( stride_shifted_indices.name, gather_idx_shape, reshaped_idxs.name, - shape_name_override=f"{node.name}_{i}_shape", + shape_name_override=f"{output.name}_{i}_shape", ) # Guarantees that the accumulation tensor is properly @@ -218,7 +217,7 @@ def define_node( values.name, gather_vals_shape, reshaped_input.name, - shape_name_override=f"{node.name}_index_shape", + shape_name_override=f"{output.name}_index_shape", ) gather_out_shape = (N, W, C) @@ -244,5 +243,5 @@ def define_node( gather_out.name, list(output_shape), output.name, - shape_name_override=f"{node.name}_output_shape", + shape_name_override=f"{output.name}_output_shape", ) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index fb26b5b8606..2b1a023d624 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index f5cf71420f4..4f3e1163c69 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 1cab28f9153..5690b82d97b 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], output.tosa_spec, ) diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 78b0b1b6675..f1cd5de6fd6 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -48,14 +48,14 @@ def define_node( output.tosa_spec, ) - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift") attr = ts.TosaSerializerAttribute() attr.MulAttribute() self._serialize_operator( node, tosa_graph, ts.Op.MUL, - [inputs[0].name, inputs[1].name, f"{node.name}_shift"], + [inputs[0].name, inputs[1].name, f"{output.name}_shift"], [output.name], attr, ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 21a8f8e1b04..99c0ecce0b2 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -56,7 +56,7 @@ def define_node( (len(multiples),), ts.DType.SHAPE, list(tosa_shape(multiples, output.dim_order)), - name=node.name + "_multiples", + name=output.name + "_multiples", ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index c5510493eae..7366703083c 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -120,7 +120,7 @@ def define_node( (starts_len,), ts.DType.SHAPE, starts, - node.name + "_start_shape", + output.name + "_start_shape", ) sizes = [size if i == dim else shape[i] for i in input_node.dim_order] @@ -130,7 +130,7 @@ def define_node( sizes_len = 1 sizes = [0] sizes_tensor = tosa_graph.addConst( - (sizes_len,), ts.DType.SHAPE, sizes, node.name + "_sizes_shape" + (sizes_len,), ts.DType.SHAPE, sizes, output.name + "_sizes_shape" ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index 2281564a0c4..be73a60f7c7 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -72,8 +72,8 @@ def define_node( else: input0_zp, input1_zp = 0, 0 - input_A_ZP_name = f"{node.name}_A_ZP" - input_B_ZP_name = f"{node.name}_B_ZP" + input_A_ZP_name = f"{output.name}_A_ZP" + input_B_ZP_name = f"{output.name}_B_ZP" tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=input_A_ZP_name) tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name) diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index feb7d1ef28a..ae87dcc9c31 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. -from typing import Any, cast, List +import math +from typing import Any, cast, List, Tuple import torch @@ -19,10 +20,197 @@ from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale from torch.fx import Node +def _compute_multiplier_and_shift( + scales: list[float], scaleWidth: int = 32 +) -> Tuple[list[int], list[int]]: + """Derive integer multipliers and shifts from floating-point scales. + + TOSA uses the RESCALE operation to scale between values with differing + precision. The RESCALE operator is defined using an integer multiply, add, + and shift. This utility function is for calculating the multiplier and shift + given a scale. + Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling + + Args: + scales (list[float]): Scale factors to decompose into multiplier and + shift pairs. + scaleWidth (int): Bit-width of the multiplier representation; expects + ``16`` or ``32``. + + Returns: + Tuple[list[int], list[int]]: Parallel lists containing the computed + multipliers and right shifts. + + Raises: + ValueError: If ``scaleWidth`` is not supported. + + """ + if scaleWidth == 16: + offset = 15 + elif scaleWidth == 32: + offset = 31 + else: + raise ValueError( + f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values." + ) + + multipliers = [] + shifts = [] + for scale in scales: + mantissa, exponent = math.frexp(scale) + shift = exponent + + const_2_power_15_or_31 = 1 << offset + shifted_mantissa = round(mantissa * const_2_power_15_or_31) + + assert ( + shifted_mantissa <= const_2_power_15_or_31 + ), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}" + + if shifted_mantissa == const_2_power_15_or_31: + shifted_mantissa = shifted_mantissa // 2 + shift += 1 + + # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. + shift = offset - shift + + # INT32_MAX, 2^31 - 1 + assert shifted_mantissa <= (const_2_power_15_or_31 - 1), ( + f"Mantissa {shifted_mantissa} exceeds signed max " + f"{const_2_power_15_or_31 - 1}" + ) + + multiplier = shifted_mantissa + + if shift > 62: + multiplier = multiplier >> min(31, shift - 62) + shift = 62 + + assert multiplier >= 0, "Multiplier should be non-negative" + assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" + multipliers.append(multiplier) + shifts.append(shift) + return multipliers, shifts + + +def _create_const_ops_for_rescale( + tosa_fb, + scale_32, + input_dtype, + node_name, + multipliers, + shifts, + input_zp, + output_zp, + output_dtype, + ts, +): + """Materialize constant operands required by the TOSA RESCALE op. + + For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp + and output_zp to be const inputs. Create constant operators from the data + already initialized. + + Args: + tosa_fb (Any): Graph builder used to emit TOSA operators and tensors. + scale_32 (bool): Flag indicating whether multipliers use 32-bit width. + input_dtype (ts.DType): Data type of the input tensor. + node_name (str): Base name reused for created constant tensors. + multipliers (list[int]): Precomputed multiplier coefficients. + shifts (list[int]): Precomputed shift coefficients. + input_zp (list[int]): Quantization zero points for the input. + output_zp (list[int]): Quantization zero points for the output. + output_dtype (ts.DType): Data type of the output tensor. + ts (module): Reference to the ``tosa_serializer`` module. + + Returns: + list[str]: Names of the constant tensors added to ``tosa_fb`` in the + order expected by RESCALE. + + """ + + multipliers = tosa_fb.addConst( + (len(multipliers),), + ts.DType.INT32 if scale_32 else ts.DType.INT16, + multipliers, + name=node_name + "_multipliers", + ) + shifts = tosa_fb.addConst( + (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" + ) + input_zp = tosa_fb.addConst( + [1], input_dtype, input_zp, name=node_name + "_input_zp" + ) + output_zp = tosa_fb.addConst( + [1], output_dtype, output_zp, name=node_name + "_output_zp" + ) + + return [multipliers.name, shifts.name, input_zp.name, output_zp.name] + + +def _build_rescale( + tosa_fb: Any, + scale: list[float], + input_node: Any, + output_name: str, + output_type: Any, + input_zp: list[int], + output_zp: list[int], + rounding_mode: ts.RoundingMode, + per_channel: bool = False, + is_scale32: bool = True, +): + """Insert a TOSA RESCALE operator configured for the quantized path. + + Args: + tosa_fb (Any): Graph builder receiving the RESCALE operator. + scale (list[float]): Scale factors applied during rescaling. + input_node (Any): Input tensor node feeding the operator. + output_name (str): Name assigned to the RESCALE output tensor. + output_type (ts.DType): Data type of the output tensor. + input_zp (list[int]): Quantization zero points for the input tensor. + output_zp (list[int]): Quantization zero points for the output tensor. + rounding_mode (ts.RoundingMode): Rounding policy for the RESCALE op. + per_channel (bool): Whether scales are applied per output channel. + is_scale32 (bool): Declared scale width; ignored when the input type is + ``ts.DType.INT48``. + + """ + scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 + is_scale32 = False if input_node.dtype == ts.DType.INT48 else True + multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) + rescale_inputs = _create_const_ops_for_rescale( + tosa_fb, + is_scale32, + input_node.dtype, + output_name, + multipliers, + shifts, + input_zp, + output_zp, + output_type, + ts, + ) + attr_rescale = ts.TosaSerializerAttribute() + attr_rescale.RescaleAttribute( + scale32=is_scale32, + rounding_mode=rounding_mode, + per_channel=per_channel, + input_unsigned=False, + output_unsigned=False, + ) + + tosa_fb.addOperator( + ts.Op.RESCALE, + [input_node.name, *rescale_inputs], + [output_name], + attr_rescale, + ) + + @register_node_visitor class RescaleVisitor(NodeVisitor): target = "tosa.RESCALE.default" @@ -60,7 +248,7 @@ def define_node( f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" ) - build_rescale( + _build_rescale( tosa_graph, scale=scales, input_node=inputs[0], diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index fb8e305839f..6e6edf4fd41 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -46,13 +46,27 @@ def define_node( resize_mode = ts.ResizeMode.NEAREST align_corners = False validate_same_dtype(self.target, [inputs[0], output], ts) + + valid_dtypes = [] + if self.tosa_spec.support_integer(): + valid_dtypes.extend( + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.INT48] + ) + + if self.tosa_spec.support_float(): + valid_dtypes.extend( + [ + ts.DType.FP16, + ts.DType.FP32, + ] + ) + validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP16, ts.DType.FP32], + valid_dtypes, output.tosa_spec, ) - # tosa_shape output is NHWC, take HW input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ 1:3 @@ -84,15 +98,15 @@ def in_int16_range(x): scale_d_vals[1], ] scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" + [len(scales)], ts.DType.SHAPE, scales, output.name + "_scales" ) offset = [int(v) for v in offset_yx.tolist()] offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" + [len(offset)], ts.DType.SHAPE, offset, output.name + "_offset" ) border = [int(v) for v in border_yx.tolist()] border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, node.name + "_border" + [len(border)], ts.DType.SHAPE, border, output.name + "_border" ) attr = ts.TosaSerializerAttribute() attr.ResizeAttribute(resize_mode) diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index 11407517b6a..d867b5efd7b 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -44,27 +44,24 @@ def define_node( if inputs[0].dtype == ts.DType.INT16: validate_valid_dtype(self.target, output, ts.DType.INT32, output.tosa_spec) - if inputs[1].name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] + # The name of the table constant is a bit complex. + # The name of the pytorch buffer will be the target of last node argument. + # However, when it is serialized to TOSA, a submodule suffix might be added. The TOSA buffer name thus + # needs to be taken from the last TosaArg. + pytorch_table_buffer_name = node.args[-1].target # type: ignore[union-attr] + tosa_table_buffer_name = inputs[-1].name + if pytorch_table_buffer_name not in self._exported_program.state_dict.keys(): raise RuntimeError( f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." ) - table = self._exported_program.state_dict[inputs[1].name] # type: ignore[union-attr] - - table_tensor_name = node.name + "_table" - tosa_graph.addConst( - table.shape, - ts.DType.INT8 if inputs[0].dtype == ts.DType.INT8 else ts.DType.INT16, - table.detach().numpy(), - name=table_tensor_name, - ) attr = ts.TosaSerializerAttribute() attr.TableAttribute() self._serialize_operator( node, tosa_graph, ts.Op.TABLE, - [inputs[0].name, table_tensor_name], + [inputs[0].name, tosa_table_buffer_name], [output.name], attr, ) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index f13c386a5ee..a32cb3aac06 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -66,7 +66,7 @@ def define_node( shape_len, ts.DType.SHAPE, shape_data, - name=node.name + "_shape", + name=output.name + "_shape", ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_while.py b/backends/arm/operators/op_while.py new file mode 100644 index 00000000000..e6977474b89 --- /dev/null +++ b/backends/arm/operators/op_while.py @@ -0,0 +1,74 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, cast, List + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_cf_extension, + validate_num_inputs, +) +from executorch.backends.arm.tosa.mapping import TosaArg +from torch.fx import Node + + +@register_node_visitor +class WhileLoopVisitor(NodeVisitor): + target = "while_loop" + + tosa_specs = NodeVisitor.tosa_specs + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + validate_num_inputs(self.target, inputs, 4) + validate_cf_extension(self.target, self.tosa_spec) + + carried_inputs = inputs[2].special if hasattr(inputs[2], "special") else None + if carried_inputs is None: + raise ValueError(f"{self.target}: Expected loop input arguments to be set.") + + additional_inputs = inputs[3].special if hasattr(inputs[3], "special") else None + if additional_inputs: + raise ValueError( + "Additional inputs is not supported, use carried inputs instead." + ) + + attr = ts.TosaSerializerAttribute() + cond_graph, body_graph = (cast(Node, arg).target for arg in node.args[:2]) + attr.WhileLoopAttribute(cond_graph, body_graph) + + input_names: list[str] = [] + for loop_input in carried_inputs: + if not isinstance(loop_input, Node): + raise ValueError( + f"{self.target}: Unsupported carried input type {type(loop_input)}." + ) + input_names.append(loop_input.name) + + if len(input_names) != len(output.multiple_output_names): + raise ValueError( + f"TOSA specifies that the number of inputs, {input_names}, need to be the " + f"same as the number of outputs, {output.multiple_output_names}." + ) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.WHILE_LOOP, + input_names, + output.multiple_output_names, + attr, + ) diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 32c01143f4f..20ee10534d0 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -12,6 +12,8 @@ from math import ceil, floor from typing import Any, List, Optional +from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification + def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]): """Validate the number of inputs against expected values. @@ -150,6 +152,19 @@ def validate_valid_dtype( ) +def validate_cf_extension(op_name: str, tosa_spec: TosaSpecification) -> None: + """Ensure that the requested control-flow operator is supported by the active TOSA spec.""" + if not isinstance(tosa_spec, Tosa_1_00): + raise ValueError( + f"Got TOSA version {tosa_spec.version}, that does not support extensions." + ) + if not tosa_spec.support_extension("cf"): + raise ValueError( + f"Trying to lower {op_name}, but TOSA specification {tosa_spec} does not " + "support the cf extension." + ) + + def adjust_pooling_pad_if_needed( input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool ) -> int: diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index d570c52ed31..c4a6d78fef4 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -41,7 +41,7 @@ def define_node( output: TosaArg, ) -> None: validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) + validate_same_dtype(self.target, [inputs[0], output], ts) # Simply add an identityOp attr = ts.TosaSerializerAttribute() @@ -58,5 +58,4 @@ def define_node( register_node_visitor(IdentityOperatorVisitor) -identity_operator_factory("getitem") identity_operator_factory("aten.alias_copy.default") diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index f9694c1abdf..989e0253a65 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. # +import operator from typing import Any, cast, Dict import numpy as np @@ -14,6 +15,7 @@ from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape +from executorch.exir.graph_module import get_cond_while_submodules from torch._export.utils import ( get_buffer, get_lifted_tensor_constant, @@ -45,14 +47,18 @@ def process_call_function( f"Failed processing call_function: {node.name}. " "Is the original torch function supported?" ) from e - tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, tosa_shape(output.shape, output.dim_order), output.dtype - ) + + if not output.multiple_output_names: + tosa_graph.currRegion.currBasicBlock.addTensor( + output.name, tosa_shape(output.shape, output.dim_order), output.dtype + ) + + # Get item nodes just add tensors, no node visitor is needed. + if node.target == operator.getitem: + return # Visiting each Node - # pyre-ignore[16]: Undefined attribute. if node.target.__name__ in node_visitors: # type: ignore[union-attr] - # pyre-ignore[16]: Undefined attribute. node_visitors[node.target.__name__].define_node( # type: ignore[union-attr] node, tosa_graph, @@ -183,10 +189,59 @@ def process_inputs_to_lifted_tensor_constants( ) +def _is_submodule_input( + node: torch.fx.Node, containing_graph_module: torch.fx.GraphModule +) -> bool: + """Determines whether 'node' is an input to a submodule of 'containing_graph_module'.""" + if node.op != "placeholder": + return False + + for _, _, submodule_node in get_cond_while_submodules(containing_graph_module): + args = cast(list[torch.fx.Node], submodule_node.args[-1]) + for arg in args: + if isinstance(arg.target, str): + # If argument is a buffer or similar, we can match exactly. + if arg.target == node.name: + return True + # If argument target has a name, the submodule input is operator name + number to avoid duplication. + # For example: cond input namespace::my_op -> submodule input my_op_1 + if (name_fn := (getattr(arg.target, "name", None))) is not None: + op_name = name_fn().split(":")[-1] + if op_name in node.name: + return True + return False + + +def _submodule_has_user_input( + containing_graph_module: torch.fx.GraphModule, edge_program: ExportedProgram +): + # If argument is a user input, there is no such guarantee. We need to to a heuristic match. + for _, _, control_flow_node in get_cond_while_submodules(containing_graph_module): + match control_flow_node.target: + case torch.ops.higher_order.cond: + args = control_flow_node.args[-1] + case torch.ops.higher_order.while_loop: + args = cast(list, control_flow_node.args[-2]) + cast( + list, control_flow_node.args[-1] + ) + case _: + raise RuntimeError( + f"Unexpected control flow target: {control_flow_node.target}" + ) + args = cast(list[torch.fx.Node], args) + for arg in args: + if ( + isinstance(arg.target, str) + and arg.target in edge_program.graph_signature.user_inputs + ): + return True + + def process_placeholder( node: torch.fx.Node, tosa_graph: Any, edge_program: ExportedProgram, + containing_graph_module: torch.fx.GraphModule | None, tosa_spec: TosaSpecification, ): """Wrapper for processing and serializing all types of placeholders""" @@ -199,6 +254,8 @@ def process_placeholder( if node.name in edge_program.graph_signature.user_inputs: process_inputs(node, tosa_graph, tosa_spec) + elif containing_graph_module and _is_submodule_input(node, containing_graph_module): + process_inputs(node, tosa_graph, tosa_spec) elif is_param(edge_program, node): process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec) elif is_buffer(edge_program, node): @@ -211,6 +268,11 @@ def process_placeholder( raise NotImplementedError( "Placeholder is of type 'lifted custom object' which is not supported." ) + elif containing_graph_module and _submodule_has_user_input( + containing_graph_module, edge_program + ): + # If we are in a submodule and it has user input, process as regular input. + process_inputs(node, tosa_graph, tosa_spec) else: raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.") diff --git a/backends/arm/quantizer/__init__.py b/backends/arm/quantizer/__init__.py index e36c683416a..2018b845353 100644 --- a/backends/arm/quantizer/__init__.py +++ b/backends/arm/quantizer/__init__.py @@ -12,6 +12,7 @@ from .quantization_config import QuantizationConfig # noqa # usort: skip from .arm_quantizer import ( # noqa EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, TOSAQuantizer, VgfQuantizer, diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index e52b30895dc..f40f3a610fa 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -24,6 +24,7 @@ ArmCompileSpec, ) # isort: skip from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.exir.graph_module import get_cond_while_submodules from torch.fx import GraphModule, Node from torchao.quantization.pt2e import ( @@ -36,10 +37,16 @@ PerChannelMinMaxObserver, PlaceholderObserver, ) +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) from torchao.quantization.pt2e.quantizer import ( annotate_input_qspec_map, annotate_output_qspec, + get_module_name_filter, QuantizationSpec, Quantizer, ) @@ -51,6 +58,7 @@ "TOSAQuantizer", "EthosUQuantizer", "VgfQuantizer", + "get_symmetric_a16w8_quantization_config", "get_symmetric_quantization_config", ] @@ -201,7 +209,7 @@ def get_symmetric_a16w8_quantization_config( # 16-bit activation quantization spec act_quantization_spec = QuantizationSpec( dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min, # -32768 + quant_min=torch.iinfo(torch.int16).min + 1, # -32767 quant_max=torch.iinfo(torch.int16).max, # 32767 qscheme=torch.per_tensor_symmetric, is_dynamic=is_dynamic, @@ -241,33 +249,6 @@ def get_symmetric_a16w8_quantization_config( """ -def _get_module_name_filter(module_name: str) -> NodeFilterType: - """Get the module_name_filter function for a given module name, the filter accepts - a node and checks if the node comes from a module that has certain module name - - For example: - node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1 - - >> module_name_filter = _get_module_name_filter("blocks.sub") - >> print(module_name_filter(node)) - True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" - """ - - name_start = len("L['self'].") - - def module_name_filter(n: Node) -> bool: - # node_stack example: { - # 'L__self___sub': ("L['self'].sub", ), - # 'L__self___sub_linear': ("L['self'].sub.linear", ) - # } - # get_attr nodes doesn't have nn_module_stack? - nn_module_stack = n.meta.get("nn_module_stack", {}) - names = [name[name_start:] for name, _ in nn_module_stack.values()] - return module_name in names - - return module_name_filter - - def _get_module_type_filter(tp: Callable) -> NodeFilterType: """Get the module_type_filter function for a given module type, the filter accepts a node and checks if the node comes from a module that has certain module type @@ -299,7 +280,7 @@ def _get_not_module_type_or_name_filter( tp_list: List[Callable], module_name_list: List[str] ) -> NodeFilterType: module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + module_name_list_filters = [get_module_name_filter(m) for m in module_name_list] def not_module_type_or_name_filter(n: Node) -> bool: return not any(f(n) for f in module_type_filters + module_name_list_filters) @@ -448,7 +429,7 @@ def _annotate_for_static_quantization_config( module_name_list = list(self.module_name_config.keys()) for module_name, config in self.module_name_config.items(): self._annotate_all_static_patterns( - model, config, _get_module_name_filter(module_name) + model, config, get_module_name_filter(module_name) ) tp_list = list(self.module_type_config.keys()) @@ -480,15 +461,42 @@ def _annotate_io( ) mark_node_as_annotated(node) if node.op == "output": - parent = node.all_input_nodes[0] - annotate_input_qspec_map( - node, parent, quantization_config.get_input_act_qspec() - ) + for parent in node.all_input_nodes: + annotate_input_qspec_map( + node, parent, quantization_config.get_input_act_qspec() + ) mark_node_as_annotated(node) def validate(self, model: GraphModule) -> None: pass + def quantize_with_submodules( + self, + model: GraphModule, + calibration_samples: list[tuple], + is_qat: bool = False, + ): + """Quantizes a GraphModule in a way such that conditional submodules are handled properly. + + Args: + model: GraphModule, the model to quantize. + calibration_samples: list[tuple], a list of inputs to used to calibrate the model during quantization. + To properly calibrate a model with submodules, at least one sample per code path is needed. + is_qat: bool, whether to do quantization aware training or not. + """ + prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e + + prepared = prepare_fn(model, self) + for name, submodule, _ in get_cond_while_submodules(prepared): + prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) + for inp in calibration_samples: + prepared(*inp) + + for name, submodule, _ in get_cond_while_submodules(prepared): + prepared.set_submodule(name, convert_pt2e(submodule), strict=True) + converted = convert_pt2e(prepared) + return converted + class EthosUQuantizer(TOSAQuantizer): """ diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index c1137ea4149..7bd8e00c22b 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,6 +14,8 @@ from typing import cast +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo + from torch.fx import Node from torchao.quantization.pt2e.quantizer import QuantizationAnnotation @@ -65,4 +67,10 @@ def mark_node_as_annotated(node: Node) -> None: """ if Q_ANNOTATION_KEY not in node.meta: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation() + annotation_info = ArmAnnotationInfo( + quantized=True, + ) node.meta[Q_ANNOTATION_KEY]._annotated = True + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) + node.meta["custom"] = meta_custom diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index dc3beb5370a..1d20e1db4fe 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -75,7 +75,7 @@ def _as_list(x): list: ``x`` if already a list; otherwise ``[x]``. """ - if isinstance(x, list): + if isinstance(x, (list, tuple)): return x else: return [ @@ -394,6 +394,7 @@ def _match_pattern( torch.ops.aten.view.default, torch.ops.aten.view_as.default, torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, torch.ops.aten.select.int, torch.ops.aten.select_copy.int, torch.ops.aten.slice.Tensor, @@ -426,6 +427,7 @@ def _match_pattern( ] _one_to_one_shared_input_or_input_act_qspec = [ + torch.ops.aten.alias.default, torch.ops.aten.clone.default, torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default, @@ -693,10 +695,12 @@ def any_or_hardtanh_min_zero(n: Node): ] quant_properties.quant_output = None elif node.target in [ - torch.ops.aten.scalar_tensor.default, torch.ops.aten.full.default, torch.ops.aten.full, + torch.ops.aten.zeros.default, + torch.ops.aten.ones.default, torch.ops.aten.fill_.Scalar, + torch.ops.aten.scalar_tensor.default, ]: quant_properties.quant_inputs = [] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) @@ -707,6 +711,28 @@ def any_or_hardtanh_min_zero(n: Node): shared_qspec = SharedQuantizationSpec(input_node) quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] quant_properties.quant_output = _QuantProperty(0, shared_qspec) + elif node.target in ( + torch.ops.higher_order.cond, + torch.ops.higher_order.while_loop, + ): + submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2 + submodule_args = node.args[submodule_args_pos] + if len(submodule_args) > 0: # type: ignore[arg-type] + # The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a + # conditional graph) need shared quantization. + shared_qspec = SharedQuantizationSpec( + (cast(list[Node], submodule_args)[0], node) + ) + quant_properties.quant_inputs = [ + _QuantProperty( + submodule_args_pos, + [ + input_act_qspec, + *([shared_qspec] * (len(submodule_args) - 1)), # type: ignore[arg-type] + ], + ) + ] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) else: return None @@ -772,5 +798,7 @@ def annotate_graph( # type: ignore[return] torch.ops.aten.full, torch.ops.aten.fill_.Scalar, torch.ops.aten.scalar_tensor.default, + torch.ops.aten.zeros.default, + torch.ops.aten.ones.default, ]: node.kwargs = {} diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 36ab233bdb6..3e2939cff61 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -206,8 +206,8 @@ def _derive_qparams_fn( derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] derive_qparams_fn=_derive_qparams_fn, dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max - 1, + quant_min=torch.iinfo(torch.int32).min + 1, + quant_max=torch.iinfo(torch.int32).max, qscheme=qscheme, ch_axis=ch_axis, ) diff --git a/backends/arm/requirements-arm-models-test.txt b/backends/arm/requirements-arm-models-test.txt index ac4e1d9bad7..238e9d07c9d 100644 --- a/backends/arm/requirements-arm-models-test.txt +++ b/backends/arm/requirements-arm-models-test.txt @@ -3,4 +3,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -diffusers[torch] == 0.33.1 +diffusers[torch] == 0.33.1 \ No newline at end of file diff --git a/backends/arm/requirements-arm-tosa.txt b/backends/arm/requirements-arm-tosa.txt index da115441c52..c93e9411647 100644 --- a/backends/arm/requirements-arm-tosa.txt +++ b/backends/arm/requirements-arm-tosa.txt @@ -5,5 +5,5 @@ ml_dtypes == 0.5.1 flatbuffers == 24.3.25 -tosa-adapter-model-explorer == 0.0.1 +tosa-adapter-model-explorer == 0.1.0 ai-edge-model-explorer >= 0.1.16 diff --git a/backends/arm/requirements-arm-vgf.txt b/backends/arm/requirements-arm-vgf.txt new file mode 100644 index 00000000000..1bf4d78c995 --- /dev/null +++ b/backends/arm/requirements-arm-vgf.txt @@ -0,0 +1,8 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +ai_ml_emulation_layer_for_vulkan == 0.7.0 +ai_ml_sdk_model_converter == 0.7.0 +ai_ml_sdk_vgf_library == 0.7.0 diff --git a/backends/arm/runtime/VGFSetup.cpp b/backends/arm/runtime/VGFSetup.cpp index fa8c7ead220..fd3a114c190 100644 --- a/backends/arm/runtime/VGFSetup.cpp +++ b/backends/arm/runtime/VGFSetup.cpp @@ -707,7 +707,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { ); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to create DataGraphPipeline"); - return result; + return false; } // prepare the graph pipeline session @@ -721,7 +721,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vk_device, &pipeline_session_info, nullptr, &vk_session); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to create DataGraphPipelineSession"); - return result; + return false; } // Allocate command buffer @@ -735,7 +735,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vk_device, &buffer_allocate_info, &vk_execute_cmd); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to allocate command buffers"); - return result; + return false; } // Allocate intermediates memory based on the pipeline requirements provided @@ -753,7 +753,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vk_device, &bind_point_requirements_info, &bind_point_count, nullptr); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to get session bind point count"); - return result; + return false; } vector @@ -766,7 +766,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { bind_point_requirements.data()); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to get session bind point requirements"); - return result; + return false; } // Given the bind points, just make individual allocations and bind them @@ -777,18 +777,18 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { ET_LOG( Error, "Expected VK_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_TYPE_MEMORY_ARM"); - return VK_ERROR_UNKNOWN; + return false; } if (bind_point_requirement.bindPoint != VK_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_TRANSIENT_ARM) { ET_LOG( Error, "Expected VK_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_TRANSIENT_ARM"); - return VK_ERROR_UNKNOWN; + return false; } if (bind_point_requirement.numObjects != 1) { ET_LOG(Error, "Expected only one object for the bindpoint"); - return VK_ERROR_UNKNOWN; + return false; } VkDataGraphPipelineSessionMemoryRequirementsInfoARM memory_requirements_info = { @@ -821,7 +821,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { vkAllocateMemory(vk_device, &memory_allocate_info, nullptr, &memory); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to allocate memory for intermediates"); - return result; + return false; } // so we can free this object in destructor intermediates.push_back(memory); @@ -839,7 +839,7 @@ bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef specs) { result = vkBindDataGraphPipelineSessionMemoryARM(vk_device, 1, &bind_info); if (result != VK_SUCCESS) { ET_LOG(Error, "Failed to bind intermediates memory"); - return result; + return false; } } diff --git a/backends/arm/scripts/build_executor_runner_vkml.sh b/backends/arm/scripts/build_executor_runner_vkml.sh index 61edf3fbbe4..16074bc8ead 100755 --- a/backends/arm/scripts/build_executor_runner_vkml.sh +++ b/backends/arm/scripts/build_executor_runner_vkml.sh @@ -6,39 +6,43 @@ set -eu -script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) et_root_dir=$(cd ${script_dir}/../../.. && pwd) et_root_dir=$(realpath ${et_root_dir}) setup_path_script=${et_root_dir}/examples/arm/ethos-u-scratch/setup_path.sh _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly install necessary tools." -source "${script_dir}/utils.sh" - build_type="Release" build_with_etdump=false extra_build_flags="" output_folder="cmake-out-vkml" +build_with_etdump_flags="-DEXECUTORCH_ENABLE_EVENT_TRACER=OFF" +build_with_bundleio_flags="-DEXECUTORCH_ENABLE_BUNDLE_IO=OFF" + +source "${script_dir}/utils.sh" + -build_with_etdump_flags="-DEXECUTORCH_ENABLE_EVENT_TRACER=OFF -DEXECUTORCH_BUILD_DEVTOOLS=OFF" help() { echo "Usage: $(basename $0) [options]" echo "Options:" - echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" - echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" - echo " --extra_build_flags= Extra flags to pass to cmake. Default: none " - echo " --output= Output folder Default: $(output_folder)" + echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" + echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" + echo " --extra_build_flags= Extra flags to pass to cmake. Default: none " + echo " --output= Output folder Default: $(output_folder)" + echo " --bundleio Support BundleIO using Devtools with Input/RefOutput included" exit 0 } for arg in "$@"; do case $arg in - -h|--help) help ;; - --build_type=*) build_type="${arg#*=}";; - --etdump) build_with_etdump=true ;; - --extra_build_flags=*) extra_build_flags="${arg#*=}";; - --output=*) output_folder="${arg#*=}";; - --select_ops_list=*) select_ops_list="${arg#*=}";; - *) - ;; + -h|--help) help ;; + --build_type=*) build_type="${arg#*=}";; + --etdump) build_with_etdump=true ;; + --extra_build_flags=*) extra_build_flags="${arg#*=}";; + --output=*) output_folder="${arg#*=}";; + --select_ops_list=*) select_ops_list="${arg#*=}";; + --bundleio) build_with_bundleio_flags="-DEXECUTORCH_ENABLE_BUNDLE_IO=ON" ;; + *) + ;; esac done @@ -52,23 +56,24 @@ source ${setup_path_script} mkdir -p "${output_folder}" output_folder=$(realpath "${output_folder}") -echo "--------------------------------------------------------------------------------" -echo "Build Arm VKML executor runner: '${output_folder}' with extra build flags: ${extra_build_flags}" -echo "--------------------------------------------------------------------------------" - cd ${et_root_dir}/examples/arm/executor_runner if [ "$build_with_etdump" = true ] ; then - build_with_etdump_flags="-DEXECUTORCH_ENABLE_EVENT_TRACER=ON -DEXECUTORCH_BUILD_DEVTOOLS=ON" + build_with_etdump_flags="-DEXECUTORCH_ENABLE_EVENT_TRACER=ON" fi -echo "Building with extra flags: ${build_with_etdump_flags} ${extra_build_flags}" +echo "-----------------------------------------------------------------------------------------------" +echo "Build Arm VKML executor runner: '${output_folder}' with extra build flags: " +echo "${build_with_etdump_flags} ${build_with_bundleio_flags} ${extra_build_flags}" +echo "-----------------------------------------------------------------------------------------------" + cmake \ -S "${et_root_dir}" \ -B "${output_folder}" \ -Wall \ -Werror \ -DCMAKE_BUILD_TYPE=${build_type} \ + -DCMAKE_CXX_FLAGS="${extra_build_flags} ${CMAKE_CXX_FLAGS:-}" \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ @@ -80,9 +85,10 @@ cmake \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON \ -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DEXECUTORCH_BUILD_DEVTOOLS=ON \ -DPYTHON_EXECUTABLE="$(which python3)" \ - ${build_with_etdump_flags} \ - ${extra_build_flags} + ${build_with_etdump_flags} \ + ${build_with_bundleio_flags} echo "[${BASH_SOURCE[0]}] Configured CMAKE" diff --git a/backends/arm/scripts/build_executorch.sh b/backends/arm/scripts/build_executorch.sh index e4cc02d20c6..8597154971a 100755 --- a/backends/arm/scripts/build_executorch.sh +++ b/backends/arm/scripts/build_executorch.sh @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. # Optional parameter: -# --build_type= "Release" | "Debug" | "RelWithDebInfo" +# --build_type= "Release" | "Debug" | "RelWithDebInfo" | "UndefinedSanitizer" # --etdump build with devtools-etdump support set -eu @@ -28,7 +28,7 @@ help() { echo "Usage: $(basename $0) [options]" echo "Options:" echo " --et_build_root= Build output root folder to use, defaults to ${et_build_root}" - echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" + echo " --build_type= Build with Release, Debug, RelWithDebInfo or UndefinedSanitizer, default is ${build_type}" echo " --devtools Build Devtools libs" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" echo " --toolchain= Toolchain can be specified (e.g. bare metal as arm-none-eabi-gcc or zephyr as arm-zephyr-eabi-gcc Default: ${toolchain}" @@ -78,7 +78,7 @@ cd "${et_root_dir}" # Build cmake -DCMAKE_TOOLCHAIN_FILE=${toolchain_cmake} \ --DCMAKE_BUILD_TYPE=Release \ +-DCMAKE_BUILD_TYPE=${build_type} \ -DEXECUTORCH_BUILD_DEVTOOLS=$build_devtools \ -DEXECUTORCH_BUILD_ARM_ETDUMP=$build_with_etdump \ --preset arm-baremetal -B${et_build_dir} diff --git a/backends/arm/scripts/install_models_for_test.sh b/backends/arm/scripts/install_models_for_test.sh index 001d733a014..d6a7b9cdec0 100644 --- a/backends/arm/scripts/install_models_for_test.sh +++ b/backends/arm/scripts/install_models_for_test.sh @@ -6,3 +6,16 @@ set -e pip install -r backends/arm/requirements-arm-models-test.txt + +# Install model gym repository +git clone https://github.com/arm/neural-graphics-model-gym.git +cd neural-graphics-model-gym +# Remove model-converter installation from model-gym repository (to prevent overwriting executorch version) +if [[ "$(uname)" == "Darwin" ]]; then + sed -i '' 's/^model-converter = "ng_model_gym.bin.model_converter_launcher:main"/#&/' pyproject.toml +else + sed -i 's/^model-converter = "ng_model_gym.bin.model_converter_launcher:main"/#&/' pyproject.toml +fi +pip install . --no-deps +cd .. +rm -rf neural-graphics-model-gym \ No newline at end of file diff --git a/backends/arm/scripts/mlsdk_utils.sh b/backends/arm/scripts/mlsdk_utils.sh index 2257bc674ca..95aa5cf2a4f 100755 --- a/backends/arm/scripts/mlsdk_utils.sh +++ b/backends/arm/scripts/mlsdk_utils.sh @@ -205,7 +205,51 @@ function setup_path_emulation_layer() { model_emulation_layer_path="$(cd "${mlsdk_manifest_dir}/sw/emulation-layer/" && pwd)" prepend_env_in_setup_path LD_LIBRARY_PATH "${model_emulation_layer_path}/deploy/lib" prepend_env_in_setup_path DYLD_LIBRARY_PATH "${model_emulation_layer_path}/deploy/lib" + prepend_env_in_setup_path VK_LAYER_PATH "${model_emulation_layer_path}/deploy/share/vulkan/explicit_layer.d" prepend_env_in_setup_path VK_INSTANCE_LAYERS VK_LAYER_ML_Tensor_Emulation prepend_env_in_setup_path VK_INSTANCE_LAYERS VK_LAYER_ML_Graph_Emulation - prepend_env_in_setup_path VK_LAYER_PATH "${model_emulation_layer_path}/deploy/share/vulkan/explicit_layer.d" +} + +function setup_path_emulation_layer_from_pip() { + if ! command -v emulation_layer >/dev/null 2>&1; then + echo "[mlsdk_utils] 'emulation_layer' command not found; skipping pip emulation layer path setup" + return + fi + + local output + if ! output=$(emulation_layer 2>/dev/null); then + echo "[mlsdk_utils] Failed to query emulation_layer environment; skipping" + return + fi + + local exports + exports=$(echo "$output" | grep '^export ' || true) + + local ld_line + ld_line=$(echo "$exports" | grep 'LD_LIBRARY_PATH=' || true) + if [[ -n "${ld_line}" ]]; then + local ld_value=${ld_line#export LD_LIBRARY_PATH=} + ld_value=${ld_value%%:\$LD_LIBRARY_PATH*} + if [[ -n "${ld_value}" ]]; then + prepend_env_in_setup_path LD_LIBRARY_PATH "${ld_value}" + fi + fi + + local vk_add_line + vk_add_line=$(echo "$exports" | grep 'VK_ADD_LAYER_PATH=' || true) + if [[ -n "${vk_add_line}" ]]; then + local vk_add_value=${vk_add_line#export VK_ADD_LAYER_PATH=} + if [[ -n "${vk_add_value}" ]]; then + prepend_env_in_setup_path VK_ADD_LAYER_PATH "${vk_add_value}" + fi + fi + + local vk_instance_line + vk_instance_line=$(echo "$exports" | grep 'VK_INSTANCE_LAYERS=' || true) + if [[ -n "${vk_instance_line}" ]]; then + local vk_instance_value=${vk_instance_line#export VK_INSTANCE_LAYERS=} + if [[ -n "${vk_instance_value}" ]]; then + prepend_env_in_setup_path VK_INSTANCE_LAYERS "${vk_instance_value}" + fi + fi } diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index a663ba2e8b7..1315358b40b 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -7,6 +7,7 @@ # Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here. CUSTOM_EDGE_OPS = [ "linspace.default", + "cond.default", "eye.default", "expm1.default", "vector_norm.default", @@ -18,7 +19,9 @@ "multihead_attention.default", "adaptive_avg_pool2d.default", "bitwise_right_shift.Tensor", + "bitwise_right_shift.Scalar", "bitwise_left_shift.Tensor", + "bitwise_left_shift.Scalar", "native_group_norm.default", "silu.default", "sdpa.default", @@ -30,6 +33,7 @@ "alias_copy.default", "pixel_shuffle.default", "pixel_unshuffle.default", + "while_loop.default", ] ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS diff --git a/backends/arm/scripts/run_vkml.sh b/backends/arm/scripts/run_vkml.sh index 8a64a937638..bb2d5844642 100755 --- a/backends/arm/scripts/run_vkml.sh +++ b/backends/arm/scripts/run_vkml.sh @@ -19,6 +19,7 @@ _setup_msg="please refer to ${et_root_dir}/examples/arm/setup.sh to properly ins model="" +opt_flags="" build_path="cmake-out-vkml" converter="model-converter" @@ -33,6 +34,7 @@ help() { for arg in "$@"; do case $arg in -h|--help) help ;; + --optional_flags=*) opt_flags="${arg#*=}";; --model=*) model="${arg#*=}";; --build_path=*) build_path="${arg#*=}";; *) @@ -50,16 +52,21 @@ if [[ -z ${model} ]]; then echo "Model name needs to be provided"; exit 1; fi source ${setup_path_script} -# basic checks before we get started -hash ${converter} \ - || { echo "Could not find ${converter} on PATH, ${_setup_msg}"; exit 1; } +if ! command -v "${converter}" >/dev/null 2>&1; then + if command -v model_converter >/dev/null 2>&1; then + converter="model_converter" + fi +fi + +command -v "${converter}" >/dev/null 2>&1 \ + || { echo "Could not find a model converter executable (tried model-converter, model_converter). ${_setup_msg}"; exit 1; } +runner=$(find ${build_path} -name executor_runner -type f) -runner="${build_path}/executor_runner" echo "--------------------------------------------------------------------------------" -echo "Running ${model} with ${runner}" +echo "Running ${model} with ${runner} ${opt_flags}" echo "WARNING: The VK_ML layer driver will not provide accurate performance information" echo "--------------------------------------------------------------------------------" @@ -75,7 +82,7 @@ fi log_file=$(mktemp) -${nobuf} ${runner} -model_path ${model} | tee ${log_file} +${nobuf} ${runner} -model_path ${model} ${opt_flags} | tee ${log_file} echo "[${BASH_SOURCE[0]}] execution complete, $?" # Most of these can happen for bare metal or linx executor_runner runs. diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py index 8a08c74efc4..db07824432f 100644 --- a/backends/arm/test/conftest.py +++ b/backends/arm/test/conftest.py @@ -29,7 +29,7 @@ def pytest_configure(config): if config.option.arm_run_tosa_version: pytest._test_options["tosa_version"] = config.option.arm_run_tosa_version - logging.basicConfig(level=logging.INFO, stream=sys.stdout) + logging.basicConfig(stream=sys.stdout) def pytest_collection_modifyitems(config, items): diff --git a/backends/arm/test/misc/test_call_operator_submodule.py b/backends/arm/test/misc/test_call_operator_submodule.py new file mode 100644 index 00000000000..799c546e24e --- /dev/null +++ b/backends/arm/test/misc/test_call_operator_submodule.py @@ -0,0 +1,72 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import torch + +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager +from executorch.backends.arm.tosa.specification import TosaSpecification +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassResult + + +class _DepthRecordingPass(ArmPass): + _passes_required_after = set() + + def __init__(self, initial_graph_module): + super().__init__() + self.depths: list[int] = [] + self.initial_submodule = initial_graph_module + self.submodule = None + self.num_submodules_called = 0 + + def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False): + """Should only be called from the top-level graph module.""" + self.depths.append(self.submodule_depth) + assert self.submodule == self.initial_submodule + return super().call_operator(op, args, kwargs, meta, updated) + + def call_submodule( + self, graph_module: GraphModule, inputs: tuple[Any, ...] + ) -> PassResult: + """Should be called for all three graph_modules: top-level, if, and else.""" + self.submodule = graph_module + self.num_submodules_called += 1 + return super().call_submodule(graph_module, inputs) + + +class _CondModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def _true_branch(arg: torch.Tensor) -> torch.Tensor: + return arg + 1 + + def _false_branch(arg: torch.Tensor) -> torch.Tensor: + return arg - 1 + + predicate = x.sum() > 0 + return torch.cond(predicate, _true_branch, _false_branch, [x]) + + +def test_call_operator_runs_once_for_cond_submodules() -> None: + module = _CondModule() + example_inputs = (torch.randn(2, 3),) + exported = torch.export.export(module, example_inputs) + graph_module = exported.graph_module + + recording_pass = _DepthRecordingPass(graph_module) + pass_manager = ArmPassManager(TosaSpecification.create_from_string("TOSA-1.00+FP")) + pass_manager.add_pass(recording_pass) + pass_manager._transform(graph_module) + + assert recording_pass.num_submodules_called == 3 + assert recording_pass.depths, "call_operator was never invoked" + assert ( + max(recording_pass.depths) == 1 + ), "call_operator was invoked with larger than one submodule depth." + assert ( + min(recording_pass.depths) == 1 + ), "call_operator was invoked with zero submodule depth." diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 40dccc4197e..6e961457db4 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -23,7 +23,6 @@ ) from executorch.backends.test.harness.stages import StageType - input_t1 = Tuple[torch.Tensor] # Input x @@ -261,14 +260,14 @@ def test_dump_tosa_debug_tosa(test_data: input_t1): @common.parametrize("test_data", Linear.inputs) -def test_dump_tosa_ops(caplog, test_data: input_t1): +def test_dump_tosa_ops(capsys, test_data: input_t1): aten_ops: list[str] = [] exir_ops: list[str] = [] pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops) pipeline.pop_stage("run_method_and_compare_outputs") pipeline.dump_operator_distribution("to_edge_transform_and_lower") pipeline.run() - assert "TOSA operators:" in caplog.text + assert "TOSA operators:" in capsys.readouterr().out class Add(torch.nn.Module): @@ -282,7 +281,7 @@ def forward(self, x): @common.parametrize("test_data", Add.inputs) @common.XfailIfNoCorstone300 -def test_fail_dump_tosa_ops(caplog, test_data: input_t1): +def test_fail_dump_tosa_ops(capsys, test_data: input_t1): aten_ops: list[str] = [] exir_ops: list[str] = [] pipeline = EthosU55PipelineINT[input_t1]( @@ -290,4 +289,7 @@ def test_fail_dump_tosa_ops(caplog, test_data: input_t1): ) pipeline.dump_operator_distribution("to_edge_transform_and_lower") pipeline.run() - assert "Can not get operator distribution for Vela command stream." in caplog.text + assert ( + "Can not get operator distribution for Vela command stream." + in capsys.readouterr().out + ) diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py index d6d6d6cb39c..46a97fff1df 100644 --- a/backends/arm/test/misc/test_int64.py +++ b/backends/arm/test/misc/test_int64.py @@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor): ConstAdd(torch.int64, 2**40), (torch.rand(10) - 0.5,), ), - "int64_in+float_const": ( - ConstAdd(torch.float32), - (torch.randint(0, 10, (10,)),), - ), "fp32_in+int64_buffer_chain": ( BufferChainAdd(torch.int64), (torch.rand(2, 5, 3) - 0.5,), @@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple): ArmTester( model, inputs, - common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"), + common.get_tosa_compile_spec("TOSA-1.0+FP"), ) .export() .to_edge_transform_and_lower() diff --git a/backends/arm/test/misc/test_outputs_order.py b/backends/arm/test/misc/test_outputs_order.py index cada9e89922..253888537f8 100644 --- a/backends/arm/test/misc/test_outputs_order.py +++ b/backends/arm/test/misc/test_outputs_order.py @@ -78,14 +78,18 @@ def _read_tosa_outputs(tosa_path: Path): return shapes +# TODO: MLETORCH-1266 Investigate output order issue @pytest.mark.parametrize("batch_size", [1, 4]) -def test_network_output_order_and_restore(batch_size): +@pytest.mark.parametrize("output_order_workaround", [True, False]) +def test_network_output_order_and_restore(batch_size, output_order_workaround): model = Network(batch_norm=True).eval() # Prepare spec spec = TosaSpecification.create_from_string("TOSA-1.0+INT") - compile_spec = TosaCompileSpec(tosa_spec=spec) + tosa_compile_spec = TosaCompileSpec(spec).set_output_order_workaround( + output_order_workaround + ) # Setup quantizer - quantizer = TOSAQuantizer(compile_spec) + quantizer = TOSAQuantizer(tosa_compile_spec) quantizer.set_global( get_symmetric_quantization_config(is_qat=True, is_per_channel=False) ) @@ -100,7 +104,7 @@ def test_network_output_order_and_restore(batch_size): with tempfile.TemporaryDirectory(dir="") as tmpdir: art_dir = Path(tmpdir) part = TOSAPartitioner( - TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir)) + tosa_compile_spec.dump_intermediate_artifacts_to(str(art_dir)) ) _ = to_edge_transform_and_lower(aten_gm, partitioner=[part]) # Expect exactly one .tosa file in the artefact dir diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py new file mode 100644 index 00000000000..d18a1d39e45 --- /dev/null +++ b/backends/arm/test/misc/test_quant_custom_meta.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize + + +class AddSigmoidMul(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, y): + return self.sigmoid(x + y) * x + + +def get_selective_quantizer(modules): + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global(get_symmetric_quantization_config()) + for module in modules: + quantizer.set_module_type(module, None) + + return Quantize(quantizer, get_symmetric_quantization_config()) + + +def test_qdq_squeezed_fp_op(): + """Test that a float operation surrounded by quantize-dequantize pairs + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q + |_____Non-delegated____| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = AddSigmoidMul() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op + ) + pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid])) + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + }, + ) + pipeline.run() + + +class MulAddSigmoidConv(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + self.conv = torch.nn.Conv1d(3, 3, 1) + + def forward(self, x, y): + return self.conv(self.sigmoid(x + y * x)) + + +def test_quantized_to_float_transition(): + """Test that a model executing quantized ops followed by float ops + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv + |____Non-delegated___| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = MulAddSigmoidConv() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op + ) + pipeline.change_args( + "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d]) + ) + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 1, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + }, + ) + pipeline.run() diff --git a/backends/arm/test/misc/test_save_exported_model.py b/backends/arm/test/misc/test_save_exported_model.py new file mode 100644 index 00000000000..f393fca920c --- /dev/null +++ b/backends/arm/test/misc/test_save_exported_model.py @@ -0,0 +1,62 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.tosa import TosaSpecification +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +class SimpleModule(torch.nn.Module): + example_inputs = (torch.randn(1, 10),) + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +def test_save_load_exported_int_model(): + module = SimpleModule().eval() + example_inputs = module.example_inputs + exported_module = torch.export.export(module, example_inputs) + + # Set up quantizer + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global(get_symmetric_quantization_config()) + # Quantize model + prepared_module = prepare_pt2e(exported_module.module(), quantizer) + prepared_module(*example_inputs) + quantized_module = convert_pt2e(prepared_module) + quantized_exported_module = torch.export.export(quantized_module, example_inputs) + + base_path = "arm_test/misc/" + if not os.path.exists(base_path): + os.makedirs(base_path) + file_path = base_path + "exported_module.pt2" + # Verify that we can save the model + torch.export.save(quantized_exported_module, file_path) + + # Verify that we can load the model back + loaded_model = torch.export.load(file_path) + for original_node, loaded_node in zip( + quantized_exported_module.graph.nodes, loaded_model.graph.nodes + ): + # Verify that the custom metadata is preserved after save/load + assert original_node.meta.get("custom", {}) == loaded_node.meta.get( + "custom", {} + ) + if original_node.target == torch.ops.aten.linear.default: + assert ArmAnnotationInfo.CUSTOM_META_KEY in original_node.meta.get( + "custom", {} + ) diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py index 190c50f4aa1..91a5bc19728 100644 --- a/backends/arm/test/misc/test_tosa_spec.py +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -5,7 +5,11 @@ import unittest -from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification +from executorch.backends.arm.tosa.specification import ( + Tosa_1_00, + TosaSpecification, + TosaSpecMapping, +) from parameterized import parameterized # type: ignore[import-untyped] @@ -66,3 +70,100 @@ def test_correct_string_representation(self, version_string: str): tosa_spec = TosaSpecification.create_from_string(version_string) assert isinstance(tosa_spec, Tosa_1_00) assert f"{tosa_spec}" == version_string + + +class TestTosaSpecMapping(unittest.TestCase): + """Tests the TosaSpecMapping class""" + + def test_mapping(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + # check that the mapping is correct + vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + assert vals == ["A"] + assert len(vals) == 1 + + def test_mapping_multiple(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B") + # check that the mapping is correct + vals = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + assert vals == ["A", "B"] + assert len(vals) == 2 + + def test_mapping_different_profiles(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B") + # check that the mapping is correct + vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP")) + + assert vals_int == ["A"] + assert vals_fp == ["B"] + assert len(vals_int) == 1 + assert len(vals_fp) == 1 + + def test_mapping_different_profiles_combined_consumer(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "B") + # check that the mapping is correct + combined_vals = mapping.get( + TosaSpecification.create_from_string("TOSA-1.0+INT+FP") + ) + + assert "A" in combined_vals + assert "B" in combined_vals + assert len(combined_vals) == 2 + + def test_mapping_no_spec(self): + mapping = TosaSpecMapping() + with self.assertRaises(KeyError): + mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + def test_mapping_no_values_for_spec(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A") + with self.assertRaises(KeyError): + mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + def test_spec_with_different_profiles(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+FP"), "A") + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "B") + # check that the mapping is correct + vals_int = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT")) + vals_fp = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+FP")) + vals_int_fp = mapping.get( + TosaSpecification.create_from_string("TOSA-1.0+INT+FP") + ) + + assert vals_fp == ["A"] + assert vals_int == ["B"] + assert len(vals_int) == 1 + assert len(vals_fp) == 1 + assert len(vals_int_fp) == 2 + + def test_combined_profiles(self): + mapping = TosaSpecMapping() + with self.assertRaises(ValueError): + # Don't allow multiple profiles in a single spec + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT+FP"), "A") + + def test_spec_add_with_extension(self): + mapping = TosaSpecMapping() + with self.assertRaises(ValueError): + mapping.add( + TosaSpecification.create_from_string("TOSA-1.0.0+INT+int16"), "A" + ) + + def test_spec_non_canonical_key(self): + mapping = TosaSpecMapping() + mapping.add(TosaSpecification.create_from_string("TOSA-1.0+INT"), "A") + + val = mapping.get(TosaSpecification.create_from_string("TOSA-1.0+INT+u55")) + assert val == ["A"] diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py index 9506fe727db..6444b8417f2 100644 --- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py +++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py @@ -7,7 +7,9 @@ from typing import Tuple import torch -from diffusers.models.transformers import SD3Transformer2DModel +from diffusers.models.transformers import ( # type: ignore[import-not-found] + SD3Transformer2DModel, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( @@ -37,7 +39,8 @@ class TestSD3Transformer2DModel: ops_after_partitioner_INT = { "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, - "torch.ops.higher_order.executorch_call_delegate": 2, + "torch.ops.higher_order.executorch_call_delegate": 3, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, } def _prepare_inputs( diff --git a/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py index a3c3a018131..5d33576a817 100644 --- a/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py +++ b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py @@ -7,8 +7,12 @@ from typing import Tuple import torch -from diffusers.models.autoencoders import AutoencoderKL -from diffusers.utils.testing_utils import floats_tensor +from diffusers.models.autoencoders import ( # type: ignore[import-not-found] + AutoencoderKL, +) +from diffusers.utils.testing_utils import ( # type: ignore[import-not-found] + floats_tensor, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import ( diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 6302528e4ae..f5a4c8c5053 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -18,7 +18,7 @@ VgfPipeline, ) -from torchaudio.models import Conformer +from torchaudio.models import Conformer # type: ignore[import-untyped] input_t = Tuple[torch.Tensor, torch.IntTensor] # Input x, y diff --git a/backends/arm/test/models/test_deit_tiny_arm.py b/backends/arm/test/models/test_deit_tiny_arm.py index 22685a079bd..b95c31f628a 100644 --- a/backends/arm/test/models/test_deit_tiny_arm.py +++ b/backends/arm/test/models/test_deit_tiny_arm.py @@ -7,7 +7,7 @@ from typing import Tuple -import timm +import timm # type: ignore[import-untyped] import torch @@ -19,11 +19,13 @@ VgfPipeline, ) -from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from torchvision import transforms +from timm.data import ( # type: ignore[import-untyped] + IMAGENET_INCEPTION_MEAN, + IMAGENET_INCEPTION_STD, +) +from torchvision import transforms # type: ignore[import-untyped] logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) deit_tiny = timm.models.deit.deit_tiny_patch16_224(pretrained=True) diff --git a/backends/arm/test/models/test_inception_v3_arm.py b/backends/arm/test/models/test_inception_v3_arm.py index 2cb180a87ea..13dfac3199f 100644 --- a/backends/arm/test/models/test_inception_v3_arm.py +++ b/backends/arm/test/models/test_inception_v3_arm.py @@ -5,11 +5,12 @@ from typing import Tuple -import common import pytest import torch +from executorch.backends.arm.test import common + from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -18,7 +19,7 @@ VgfPipeline, ) -from torchvision import models, transforms +from torchvision import models, transforms # type: ignore[import-untyped] ic3 = models.inception_v3(weights=models.Inception_V3_Weights) ic3 = ic3.eval() diff --git a/backends/arm/test/models/test_lstm_arm.py b/backends/arm/test/models/test_lstm_arm.py index d9691efab25..42744c151fc 100644 --- a/backends/arm/test/models/test_lstm_arm.py +++ b/backends/arm/test/models/test_lstm_arm.py @@ -5,9 +5,14 @@ from typing import Tuple +import pytest import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -16,6 +21,9 @@ VgfPipeline, ) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize + from torch.nn.quantizable.modules import rnn input_t = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] # (h0, c0) @@ -134,3 +142,71 @@ def test_lstm_vgf_FP(): use_to_edge_transform_and_lower=True, ) pipeline.run() + + +def get_symmetric_a16w8_lstm_quantizer(per_channel_quantization=False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization, epsilon=2**-16 + ) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization, epsilon=2**-16 + ), + ) + + +def test_lstm_16a8w_tosa_INT(): + """Test LSTM model with 16A8W quantization (16-bit activations, 8-bit weights)""" + + pipeline = TosaPipelineINT[input_t]( + TestLSTM.lstm, + TestLSTM.model_example_inputs, + aten_op=[], + exir_op=[], + per_channel_quantization=False, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer()) + pipeline.run() + + +@pytest.mark.xfail( + reason="MLETORCH-1452: AssertionError: Output 0 does not match reference output." +) +@common.XfailIfNoCorstone300 +def test_lstm_16a8w_u55_INT(): + pipeline = EthosU55PipelineINT[input_t]( + TestLSTM.lstm, + TestLSTM.model_example_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + ) + + pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer()) + pipeline.run() + + +@common.XfailIfNoCorstone320 +def test_lstm_16a8w_u85_INT(): + pipeline = EthosU85PipelineINT[input_t]( + TestLSTM.lstm, + TestLSTM.model_example_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.change_args("quantize", get_symmetric_a16w8_lstm_quantizer()) + pipeline.run() diff --git a/backends/arm/test/models/test_mobilenet_v3_arm.py b/backends/arm/test/models/test_mobilenet_v3_arm.py index f3a8f27428b..0a9c5ba27fc 100644 --- a/backends/arm/test/models/test_mobilenet_v3_arm.py +++ b/backends/arm/test/models/test_mobilenet_v3_arm.py @@ -5,11 +5,12 @@ from typing import Tuple -import common import pytest import torch +from executorch.backends.arm.test import common + from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -18,7 +19,7 @@ VgfPipeline, ) -from torchvision import models, transforms +from torchvision import models, transforms # type: ignore[import-untyped] mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights) mv3 = mv3.eval() diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index 4896074b544..7d1ae64b63e 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data): @parametrize( "test_data", module_tests, - {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"}, ) def test_nn_functional_INT(test_data): module, inputs = test_data @@ -111,8 +110,10 @@ def test_nn_functional_INT(test_data): ) pipeline.pop_stage("check.aten") pipeline.pop_stage("check_count.exir") - pipeline.pop_stage("check.quant_nodes") - pipeline.pop_stage("check_not.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check_not.quant_nodes"): + pipeline.pop_stage("check_not.quant_nodes") try: pipeline.run() except RuntimeError as e: diff --git a/backends/arm/test/models/test_nn_modules.py b/backends/arm/test/models/test_nn_modules.py index 8192ec6887b..a1e1f6431d9 100644 --- a/backends/arm/test/models/test_nn_modules.py +++ b/backends/arm/test/models/test_nn_modules.py @@ -147,8 +147,10 @@ def test_nn_Modules_INT(test_data): ) pipeline.pop_stage("check.aten") pipeline.pop_stage("check_count.exir") - pipeline.pop_stage("check.quant_nodes") - pipeline.pop_stage("check_not.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check_not.quant_nodes"): + pipeline.pop_stage("check_not.quant_nodes") try: pipeline.run() except RuntimeError as e: diff --git a/backends/arm/test/models/test_nss.py b/backends/arm/test/models/test_nss.py new file mode 100644 index 00000000000..5f7db548109 --- /dev/null +++ b/backends/arm/test/models/test_nss.py @@ -0,0 +1,140 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import pytest +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +from huggingface_hub import hf_hub_download + +from ng_model_gym.usecases.nss.model.model_blocks import ( # type: ignore[import-not-found,import-untyped] + AutoEncoderV1, +) + +input_t = Tuple[torch.Tensor] # Input x + + +class NSS(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.auto_encoder = AutoEncoderV1() + + +def nss() -> AutoEncoderV1: + """Get an instance of NSS with weights loaded.""" + + weights = hf_hub_download( + repo_id="Arm/neural-super-sampling", filename="nss_v0.1.0_fp32.pt" + ) + + nss_model = NSS() + nss_model.load_state_dict( + torch.load(weights, map_location=torch.device("cpu"), weights_only=True), + strict=False, + ) + return nss_model.auto_encoder + + +def example_inputs(): + return (torch.randn((1, 12, 544, 960)),) + + +def test_nss_tosa_FP(): + pipeline = TosaPipelineFP[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.add_stage_after("export", pipeline.tester.dump_operator_distribution) + pipeline.run() + + +def test_nss_tosa_INT(): + pipeline = TosaPipelineINT[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@pytest.mark.skip(reason="No support for aten_upsample_nearest2d_vec on U55") +@common.XfailIfNoCorstone300 +def test_nss_u55_INT(): + pipeline = EthosU55PipelineINT[input_t]( + nss().eval(), + example_inputs(), + aten_ops=[], + exir_ops=[], + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@pytest.mark.skip( + reason="Fails at input memory allocation for input shape: [1, 12, 544, 960]" +) +@common.XfailIfNoCorstone320 +def test_nss_u85_INT(): + pipeline = EthosU85PipelineINT[input_t]( + nss().eval(), + example_inputs(), + aten_ops=[], + exir_ops=[], + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@pytest.mark.xfail( + reason="[MLETORCH-1430]: Double types are not supported in buffers in MSL" +) +@common.SkipIfNoModelConverter +def test_nss_vgf_FP(): + pipeline = VgfPipeline[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + tosa_version="TOSA-1.0+FP", + use_to_edge_transform_and_lower=True, + run_on_vulkan_runtime=True, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_nss_vgf_INT(): + pipeline = VgfPipeline[input_t]( + nss().eval(), + example_inputs(), + aten_op=[], + exir_op=[], + tosa_version="TOSA-1.0+INT", + symmetric_io_quantization=True, + use_to_edge_transform_and_lower=True, + run_on_vulkan_runtime=True, + ) + pipeline.run() + + +ModelUnderTest = nss().eval() +ModelInputs = example_inputs() diff --git a/backends/arm/test/models/test_resnet18.py b/backends/arm/test/models/test_resnet18.py index 44abc1d34e1..3a40a3dfd06 100644 --- a/backends/arm/test/models/test_resnet18.py +++ b/backends/arm/test/models/test_resnet18.py @@ -17,7 +17,10 @@ ) from torchvision import transforms # type: ignore[import-untyped] -from torchvision.models import resnet18, ResNet18_Weights +from torchvision.models import ( # type: ignore[import-untyped] + resnet18, + ResNet18_Weights, +) model = resnet18(weights=ResNet18_Weights) model = model.eval() diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 7f9bbdba177..54a9a6ae676 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -126,10 +126,12 @@ def test_torch_fns_FP(test_data): xfails={ "nonzero": "torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u4, 0). " "Requires dynamic output shape.", + "eye": "ValueError: Failed processing buffer placeholder: aten_arange_start_step_1_pre_computed_common. " + "Is the original torch function supported?", "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", }, - strict=False, + strict=True, ) def test_torch_fns_INT(test_data): module, inputs = test_data diff --git a/backends/arm/test/models/test_w2l_arm.py b/backends/arm/test/models/test_w2l_arm.py index d62d92f5fa2..0eda5f45875 100644 --- a/backends/arm/test/models/test_w2l_arm.py +++ b/backends/arm/test/models/test_w2l_arm.py @@ -20,7 +20,7 @@ VgfPipeline, ) -from torchaudio import models +from torchaudio import models # type: ignore[import-untyped] input_t = Tuple[torch.Tensor] # Input x diff --git a/backends/arm/test/ops/test_adaptive_avg_pool2d.py b/backends/arm/test/ops/test_adaptive_avg_pool2d.py index 4411ce7f746..3e4fbcaa833 100644 --- a/backends/arm/test/ops/test_adaptive_avg_pool2d.py +++ b/backends/arm/test/ops/test_adaptive_avg_pool2d.py @@ -136,6 +136,20 @@ def test_adaptive_avg_pool2d_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +def test_adaptive_avg_pool2d_tosa_INT_a16w8(test_module): + """Test adaptive_avg_pool2d with int16 I/O quantization for TOSA INT.""" + model, input_tensor = test_module() + pipeline = TosaPipelineINT[input_t]( + model, + input_tensor, + aten_op=[], + exir_op=exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone300 def test_adaptive_avg_pool2d_u55_INT(test_module): @@ -150,6 +164,27 @@ def test_adaptive_avg_pool2d_u55_INT(test_module): pipeline.run() +# Remove high_channel_count & output_1x1_from_19 due to 2MB SRAM access on U55 +u55_test_modules = test_modules +for key in ["high_channel_count", "output_1x1_from_19"]: + u55_test_modules.pop(key) + + +@common.parametrize("test_module", u55_test_modules) +@common.XfailIfNoCorstone300 +def test_adaptive_avg_pool2d_16a8w_u55_INT16(test_module): + """Test adaptive_avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + pipeline = EthosU55PipelineINT[input_t]( + model, + input_tensor, + aten_ops=[], + exir_ops=exir_op, + a16w8_quantization=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone320 def test_adaptive_avg_pool2d_u85_INT(test_module): @@ -164,6 +199,21 @@ def test_adaptive_avg_pool2d_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone320 +def test_adaptive_avg_pool2d_16a8w_u85_INT16(test_module): + """Test adaptive_avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + pipeline = EthosU85PipelineINT[input_t]( + model, + input_tensor, + aten_ops=[], + exir_ops=exir_op, + a16w8_quantization=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter def test_adaptive_avg_pool2d_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_any.py b/backends/arm/test/ops/test_any.py index 3eccff0a64e..9d973a27d41 100644 --- a/backends/arm/test/ops/test_any.py +++ b/backends/arm/test/ops/test_any.py @@ -149,8 +149,6 @@ def test_any_tosa_INT(test_data: input_t1): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -181,8 +179,6 @@ def test_any_u85_INT(test_data: input_t1): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -211,6 +207,4 @@ def test_any_vgf_INT(test_data: input_t1): op.exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_arange.py b/backends/arm/test/ops/test_arange.py index 33cca542922..3816db3a53c 100644 --- a/backends/arm/test/ops/test_arange.py +++ b/backends/arm/test/ops/test_arange.py @@ -98,7 +98,6 @@ def test_arange_start_step_tosa_INT(test_data: test_data_t): ArangeAdd.aten_op, ArangeAdd.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -111,7 +110,6 @@ def test_arange_start_step_u55_INT(test_data: test_data_t): input_data(), ArangeAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -124,7 +122,6 @@ def test_arange_start_step_u85_INT(test_data: test_data_t): input_data(), ArangeAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index 797ce26ea7a..0ed7f117ce7 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -153,6 +153,21 @@ def test_avg_pool2d_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +def test_avg_pool2d_tosa_INT_a16w8(test_module): + """Test avg_pool2d operation with int16 I/O quantization for TOSA INT.""" + model, input_tensor = test_module() + pipeline = TosaPipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + tosa_extensions=["int16"], + run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"), + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone300 def test_avg_pool2d_u55_INT(test_module): @@ -167,6 +182,23 @@ def test_avg_pool2d_u55_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone300 +def test_avg_pool2d_16a8w_u55_INT16(test_module): + """Test avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + pipeline = EthosU55PipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone320 def test_avg_pool2d_u85_INT(test_module): @@ -181,6 +213,23 @@ def test_avg_pool2d_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone320 +def test_avg_pool2d_16a8w_u85_INT16(test_module): + """Test avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + pipeline = EthosU85PipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter def test_avg_pool2d_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_bitwise.py b/backends/arm/test/ops/test_bitwise.py index f9b20e5dbdd..b80b87fdae5 100644 --- a/backends/arm/test/ops/test_bitwise.py +++ b/backends/arm/test/ops/test_bitwise.py @@ -109,8 +109,8 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor): class AndScalar(BitwiseBinaryScalar): - aten_op = "torch.ops.aten.bitwise_and.Scalar" # Tensor because it gets converted from Scalar -> Tensor in lowering + aten_op = "torch.ops.aten.bitwise_and.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor" exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_and_Scalar" @@ -119,8 +119,8 @@ def forward(self, tensor: torch.Tensor, scalar: int): class XorScalar(BitwiseBinaryScalar): - aten_op = "torch.ops.aten.bitwise_xor.Scalar" # Tensor because it gets converted from Scalar -> Tensor in lowering + aten_op = "torch.ops.aten.bitwise_xor.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Tensor" exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_xor_Scalar" @@ -129,8 +129,8 @@ def forward(self, tensor: torch.Tensor, scalar: int): class OrScalar(BitwiseBinaryScalar): - aten_op = "torch.ops.aten.bitwise_or.Scalar" # Tensor because it gets converted from Scalar -> Tensor in lowering + aten_op = "torch.ops.aten.bitwise_or.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Tensor" exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_bitwise_or_Scalar" @@ -174,8 +174,6 @@ def test_bitwise_and_tensor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -190,8 +188,6 @@ def test_bitwise_and_scalar_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -239,8 +235,6 @@ def test_bitwise_and_scalar_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -256,8 +250,6 @@ def test_bitwise_and_tensor_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -296,8 +288,6 @@ def test_bitwise_and_tensor_vgf_INT(test_data: input_t2): qtol=0, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -314,8 +304,6 @@ def test_bitwise_and_scalar_vgf_INT(test_data: input_t2): qtol=0, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -355,8 +343,6 @@ def test_bitwise_xor_tensor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -371,8 +357,6 @@ def test_bitwise_xor_scalar_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -420,8 +404,6 @@ def test_bitwise_xor_tensor_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -437,8 +419,6 @@ def test_bitwise_xor_scalar_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -477,8 +457,6 @@ def test_bitwise_xor_tensor_vgf_INT(test_data: input_t2): qtol=0, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -495,8 +473,6 @@ def test_bitwise_xor_scalar_vgf_INT(test_data: input_t2): qtol=0, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -536,8 +512,6 @@ def test_bitwise_or_tensor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -552,8 +526,6 @@ def test_bitwise_or_scalar_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -601,8 +573,6 @@ def test_bitwise_or_tensor_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -618,8 +588,6 @@ def test_bitwise_or_scalar_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -658,8 +626,6 @@ def test_bitwise_or_tensor_vgf_INT(test_data: input_t2): qtol=0, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -676,8 +642,6 @@ def test_bitwise_or_scalar_vgf_INT(test_data: input_t2): qtol=0, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_bitwise_not.py b/backends/arm/test/ops/test_bitwise_not.py index 4f48bc134ba..f9d743bdc8a 100644 --- a/backends/arm/test/ops/test_bitwise_not.py +++ b/backends/arm/test/ops/test_bitwise_not.py @@ -60,8 +60,6 @@ def test_bitwise_not_tosa_INT(test_data: Tuple): aten_op=aten_op, exir_op=exir_op, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -87,8 +85,6 @@ def test_bitwise_not_u85_INT(test_data: Tuple): aten_ops=aten_op, exir_ops=exir_op, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -115,6 +111,4 @@ def test_bitwise_not_vgf_INT(test_data: Tuple): exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py index a5561802e44..88c12dd8d6c 100644 --- a/backends/arm/test/ops/test_clamp.py +++ b/backends/arm/test/ops/test_clamp.py @@ -84,6 +84,22 @@ def test_clamp_tosa_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_clamp_tosa_INT_a16w8(test_data): + """Test clamp operation with int16 I/O quantization for TOSA INT.""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = TosaPipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 def test_clamp_u55_INT(test_data): @@ -102,6 +118,25 @@ def test_clamp_u55_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_clamp_16a8w_u55_INT16(test_data): + """Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU55PipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 def test_clamp_u85_INT(test_data): @@ -120,6 +155,25 @@ def test_clamp_u85_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_clamp_16a8w_u85_INT16(test_data): + """Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU85PipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter def test_clamp_vgf_FP(test_data): diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py new file mode 100644 index 00000000000..77405354bd4 --- /dev/null +++ b/backends/arm/test/ops/test_cond.py @@ -0,0 +1,263 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + +aten_op = "torch.ops.higher_order.cond" +exir_op = "torch.ops.higher_order.cond" + +input_t1 = Tuple[torch.Tensor] +input_t2 = Tuple[torch.Tensor, torch.Tensor] + + +class CondZeroArgsOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch() -> torch.Tensor: + return torch.zeros(10) + + def false_branch() -> torch.Tensor: + return torch.ones(10) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, []) + + +class CondOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return torch.cos(arg) + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgBufferOneOutput(torch.nn.Module): + def __init__(self, *args: common.Any, **kwargs: common.Any) -> None: + super().__init__(*args, **kwargs) + self.buffer = torch.rand(2, 3) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor: + return torch.sin(arg) + buffer + + def false_branch(arg: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor: + return torch.cos(arg) + buffer + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x, self.buffer]) + + +class CondOneArgAndScalarOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def true_branch(arg: torch.Tensor) -> torch.Tensor: + return arg + 1.0 + + def false_branch(arg: torch.Tensor) -> torch.Tensor: + return arg - 1.0 + + predicate = x.sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondOneArgTwoOutputs(torch.nn.Module): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg + torch.sin(arg), arg - torch.sin(arg) + + def false_branch(arg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return arg - arg.mean(), arg + arg.mean() + + predicate = x.flatten().sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [x]) + + +class CondNestedOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def inner_true(arg: torch.Tensor) -> torch.Tensor: + return arg + torch.full((1,), (1.0)) + + def inner_false(arg: torch.Tensor) -> torch.Tensor: + return arg - torch.full((1,), (1.0)) + + def outer_true(arg: torch.Tensor) -> torch.Tensor: + inner_predicate = arg.mean() > 0 + return torch.cond(inner_predicate, inner_true, inner_false, [arg]) + + def outer_false(arg: torch.Tensor) -> torch.Tensor: + return arg * torch.full((1,), (1.0)) + + predicate = x.sum() > 0 + return torch.cond(predicate, outer_true, outer_false, [x]) + + +class CondMultipleOneArgOneOutput(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + def first_true(arg: torch.Tensor) -> torch.Tensor: + return arg.sigmoid() + + def first_false(arg: torch.Tensor) -> torch.Tensor: + return arg.relu() + + first_predicate = x.sum() > 0 + intermediate = torch.cond(first_predicate, first_true, first_false, [x]) + + def second_true(arg: torch.Tensor) -> torch.Tensor: + return arg.sin() + + def second_false(arg: torch.Tensor) -> torch.Tensor: + return arg.cos() + + second_predicate = intermediate.mean() > 0 + return torch.cond(second_predicate, second_true, second_false, [intermediate]) + + +class CondTwoArgsOneOutput(torch.nn.Module): + def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor: + def true_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l + arg_r + + def false_branch(arg_l: torch.Tensor, arg_r: torch.Tensor) -> torch.Tensor: + return arg_l - arg_r + + predicate = (lhs - rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +class CondTwoArgsTwoOutputs(torch.nn.Module): + def forward( + self, lhs: torch.Tensor, rhs: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def true_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + return arg_l + arg_r, arg_l * arg_r + + def false_branch( + arg_l: torch.Tensor, arg_r: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + diff = arg_l - arg_r + return diff, arg_l + diff + + predicate = (lhs * rhs).sum() > 0 + return torch.cond(predicate, true_branch, false_branch, [lhs, rhs]) + + +def _single_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t1]]: + def _create() -> tuple[torch.nn.Module, input_t1]: + return module_factory(), (torch.randn(2, 3),) + + return _create + + +def _dual_input_case( + module_factory: Callable[[], torch.nn.Module] +) -> Callable[[], tuple[torch.nn.Module, input_t2]]: + def _create() -> tuple[torch.nn.Module, input_t2]: + return module_factory(), (torch.randn(2, 3), torch.randn(2, 3)) + + return _create + + +test_cases: dict[str, Callable[[], tuple[torch.nn.Module, tuple]]] = { + "zero_args_one_output": _single_input_case(CondZeroArgsOneOutput), + "one_arg_one_output": _single_input_case(CondOneArgOneOutput), + "one_arg_const_one_output": _single_input_case(CondOneArgBufferOneOutput), + "one_arg_and_scalar_one_output": _single_input_case(CondOneArgAndScalarOneOutput), + "one_arg_two_outputs": _single_input_case(CondOneArgTwoOutputs), + "two_args_one_output": _dual_input_case(CondTwoArgsOneOutput), + "two_args_two_outputs": _dual_input_case(CondTwoArgsTwoOutputs), + "nested_one_arg_one_output": _single_input_case(CondNestedOneArgOneOutput), + "multiple_one_arg_one_output": _single_input_case(CondMultipleOneArgOneOutput), +} + + +def _make_calibration_samples( + module: torch.nn.Module, example_inputs: tuple +) -> tuple[tuple[torch.Tensor, ...], ...]: + """Return one example input that triggers the if branch, and one that triggers the else branch.""" + + if isinstance(module, CondTwoArgsOneOutput): + # Predicate is sum(lhs-rhs) > 0 + lhs, rhs = example_inputs + if_example_inputs = (lhs, rhs) + else_example_inputs = (rhs, lhs) + elif isinstance(module, CondTwoArgsTwoOutputs): + # Predicate is sum(lhs*rhs) > 0 + lhs, rhs = example_inputs + if_example_inputs = (lhs, rhs) + else_example_inputs = (lhs, -rhs) + else: + # Predicate is sum(x) > 0 + (x,) = example_inputs + if_example_inputs = (x,) + else_example_inputs = (-x,) + + return (if_example_inputs, else_example_inputs) + + +@common.parametrize( + "case", + test_cases, + xfails={ + "one_arg_and_scalar_one_output": "Scalars become get_attr nodes that are not supported.", + "nested_one_arg_one_output": "Not fully delegated.", + }, +) +def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineFP[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + # Make sure no cond ops are left after partitioning. + pipeline.add_stage_after( + "to_edge_transform_and_lower", + ArmTester.check_not, + pipeline.tester, + ["torch.ops.higher_order.cond"], + ) + pipeline.run() + + +@common.parametrize( + "case", + test_cases, + xfails={ + "zero_args_one_output": "Since the submodules have no input, the tracer fails finding a fake tensor mode," + " and traces the graph with real tensors, which tosa.RESCALE can't handle.", + "one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.", + "nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0", + }, +) +def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineINT[tuple]( + module, example_inputs, aten_op, tosa_extensions=["cf"] + ) + calibration_samples = _make_calibration_samples(module, example_inputs) + quant_stage_pos = pipeline.find_pos("quantize") + quant_stage = pipeline._stages[quant_stage_pos].args[0] + quant_stage.calibration_samples = calibration_samples + + # Make sure no cond ops are left after partitioning. + pipeline.add_stage_after( + "to_edge_transform_and_lower", + ArmTester.check_not, + pipeline.tester, + ["torch.ops.higher_order.cond"], + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py index d70249c31d1..437c4bee9ef 100644 --- a/backends/arm/test/ops/test_constant_pad_nd.py +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -77,6 +77,20 @@ def test_constant_pad_nd_tosa_INT(test_data: Tuple): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_constant_pad_nd_tosa_INT_a16w8(test_data: Tuple): + """Test constant_pad_nd op with int16 I/O quantization for TOSA INT.""" + test_data, padding, value = test_data() + pipeline = TosaPipelineINT[input_t1]( + ConstantPadND(padding, value), + (test_data,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter def test_constant_pad_nd_vgf_FP(test_data: Tuple): diff --git a/backends/arm/test/ops/test_eq.py b/backends/arm/test/ops/test_eq.py index 8f783240a2c..e49f09471fa 100644 --- a/backends/arm/test/ops/test_eq.py +++ b/backends/arm/test/ops/test_eq.py @@ -121,6 +121,30 @@ def test_eq_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_eq_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_eq_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_eq_scalar_u55_INT_tensor(test_module): @@ -150,14 +174,7 @@ def test_eq_scalar_u55_INT(test_module): pipeline.run() -@common.parametrize( - "test_module", - test_data_tensor, - xfails={ - "eq_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85", - }, - strict=False, -) +@common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone320 def test_eq_scalar_u85_INT_tensor(test_module): pipeline = EthosU85PipelineINT[input_t]( @@ -169,14 +186,7 @@ def test_eq_scalar_u85_INT_tensor(test_module): pipeline.run() -@common.parametrize( - "test_module", - test_data_scalar, - xfails={ - "eq_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85", - }, - strict=False, -) +@common.parametrize("test_module", test_data_scalar) @common.XfailIfNoCorstone320 def test_eq_scalar_u85_INT(test_module): pipeline = EthosU85PipelineINT[input_t]( @@ -188,6 +198,42 @@ def test_eq_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_eq_tensor_16a8w_u85_INT16(test_module): + """Test eq operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_eq_scalar_16a8w_u85_INT16(test_module): + """Test eq operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_eq_scalar_vgf_FP_tensor(test_module): diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py index eef32259c10..f3ba4113db1 100644 --- a/backends/arm/test/ops/test_eye.py +++ b/backends/arm/test/ops/test_eye.py @@ -68,7 +68,8 @@ def test_eye_tosa_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -82,7 +83,8 @@ def test_eye_u55_INT(test_data: test_data_t): EyeAdd.aten_op, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -95,8 +97,9 @@ def test_eye_u85_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") - pipeline.pop_stage("check.quant_nodes") + ) + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -132,7 +135,8 @@ def test_eye_vgf_INT(test_data: test_data_t): EyeAdd.aten_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 8ab063e9957..d0cf162a232 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -117,7 +117,6 @@ def test_full_like_tosa_INT(test_data: Tuple): aten_op=[], exir_op=exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_ge.py b/backends/arm/test/ops/test_ge.py index ede5be76eda..b3cc1df34c9 100644 --- a/backends/arm/test/ops/test_ge.py +++ b/backends/arm/test/ops/test_ge.py @@ -121,6 +121,30 @@ def test_ge_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_ge_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_ge_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_ge_tensor_u55_INT(test_module): @@ -180,6 +204,42 @@ def test_ge_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_ge_tensor_16a8w_u85_INT16(test_module): + """Test ge operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_ge_scalar_16a8w_u85_INT16(test_module): + """Test ge operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_ge_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_gt.py b/backends/arm/test/ops/test_gt.py index 0e50b6b78be..aee617f9767 100644 --- a/backends/arm/test/ops/test_gt.py +++ b/backends/arm/test/ops/test_gt.py @@ -122,6 +122,30 @@ def test_gt_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_gt_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_gt_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_gt_tensor_u55_INT(test_module): @@ -181,6 +205,42 @@ def test_gt_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_gt_tensor_16a8w_u85_INT16(test_module): + """Test gt operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_gt_scalar_16a8w_u85_INT16(test_module): + """Test gt operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_gt_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_le.py b/backends/arm/test/ops/test_le.py index fd0e63e9beb..cc8ddfc4da2 100644 --- a/backends/arm/test/ops/test_le.py +++ b/backends/arm/test/ops/test_le.py @@ -122,6 +122,30 @@ def test_le_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_le_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_le_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_le_tensor_u55_INT_not_delegated(test_module): @@ -184,6 +208,42 @@ def test_le_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_le_tensor_16a8w_u85_INT16(test_module): + """Test le operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_le_scalar_16a8w_u85_INT16(test_module): + """Test le operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_le_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_logical.py b/backends/arm/test/ops/test_logical.py index e772840e6e6..8c290c28908 100644 --- a/backends/arm/test/ops/test_logical.py +++ b/backends/arm/test/ops/test_logical.py @@ -111,8 +111,6 @@ def test_logical_and_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -141,8 +139,6 @@ def test_logical_and_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -169,8 +165,6 @@ def test_logical_and_vgf_INT(test_data: input_t2): And().exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -204,8 +198,6 @@ def test_logical_xor_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -234,8 +226,6 @@ def test_logical_xor_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -262,8 +252,6 @@ def test_logical_xor_vgf_INT(test_data: input_t2): Xor().exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -297,8 +285,6 @@ def test_logical_or_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -327,8 +313,6 @@ def test_logical_or_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -355,8 +339,6 @@ def test_logical_or_vgf_INT(test_data: input_t2): Or().exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -390,8 +372,6 @@ def test_logical_not_tosa_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -420,8 +400,6 @@ def test_logical_not_u85_INT(test_data: input_t2): rtol=0, qtol=0, ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -448,6 +426,4 @@ def test_logical_not_vgf_INT(test_data: input_t2): Not().exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("quantize") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_lshift.py b/backends/arm/test/ops/test_lshift.py index 3af49cd4dc2..1d4224a8efe 100644 --- a/backends/arm/test/ops/test_lshift.py +++ b/backends/arm/test/ops/test_lshift.py @@ -91,7 +91,6 @@ def test_bitwise_left_shift_tensor_tosa_INT_scalar(test_data): LshiftScalar.torch_op_INT, LshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -104,7 +103,6 @@ def test_bitwise_left_shift_tensor_u55_INT_scalar(test_data): LshiftScalar.torch_op_INT, LshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -117,7 +115,6 @@ def test_bitwise_left_shift_tensor_u85_INT_scalar(test_data): LshiftScalar.torch_op_INT, LshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -144,7 +141,6 @@ def test_bitwise_left_shift_tensor_vgf_INT_scalar(test_data: scalar_input_t): LshiftScalar.exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -171,7 +167,6 @@ def test_bitwise_left_shift_tensor_tosa_INT(test_data): LshiftTensor.torch_op, LshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -184,7 +179,6 @@ def test_bitwise_left_shift_tensor_u55_INT(test_data): LshiftTensor.torch_op, LshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -197,7 +191,6 @@ def test_bitwise_left_shift_tensor_u85_INT(test_data): LshiftTensor.torch_op, LshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -224,5 +217,4 @@ def test_bitwise_left_shift_tensor_vgf_INT(test_data: tensor_input_t): LshiftTensor.exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_lt.py b/backends/arm/test/ops/test_lt.py index d0ed1a34185..22958208bcd 100644 --- a/backends/arm/test/ops/test_lt.py +++ b/backends/arm/test/ops/test_lt.py @@ -122,6 +122,30 @@ def test_lt_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_lt_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_lt_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_lt_tensor_u55_INT_not_delegated(test_module): @@ -181,6 +205,42 @@ def test_lt_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_lt_tensor_16a8w_u85_INT16(test_module): + """Test lt operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_lt_scalar_16a8w_u85_INT16(test_module): + """Test lt operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_lt_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 559932848e4..21619afa7a3 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -133,6 +133,20 @@ def test_max_pool2d_tosa_INT(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_max_pool2d_tosa_INT_a16w8(test_data: torch.Tensor): + """Test max_pool2d operation with int16 I/O quantization for TOSA INT.""" + test_data, model_params = test_data() + pipeline = TosaPipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 def test_max_pool2d_u55_INT(test_data: torch.Tensor): @@ -145,6 +159,23 @@ def test_max_pool2d_u55_INT(test_data: torch.Tensor): ).run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_max_pool2d_16a8w_u55_INT16(test_data: torch.Tensor): + """Test max_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + test_data, model_params = test_data() + pipeline = EthosU55PipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 def test_max_pool2d_u85_INT(test_data: torch.Tensor): @@ -157,6 +188,23 @@ def test_max_pool2d_u85_INT(test_data: torch.Tensor): ).run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_max_pool2d_16a8w_u85_INT16(test_data: torch.Tensor): + """Test max_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + test_data, model_params = test_data() + pipeline = EthosU85PipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + reject_data_suite = { "reject_1": lambda: (MaxPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), "reject_2": lambda: (MaxPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)), diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index 02447e40c4e..2e40a244983 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -187,7 +187,6 @@ def test_mul_tensor_tosa_INT_int32(test_data: torch.Tensor): aten_op, exir_op=[], ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -224,7 +223,6 @@ def test_mul_tensor_u55_INT_int32(test_data: torch.Tensor): aten_op, exir_ops=[], ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -237,7 +235,6 @@ def test_mul_tensor_u85_INT_int32(test_data: torch.Tensor): aten_op, exir_ops=[], ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -285,7 +282,6 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor): exir_op=[], tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_ones.py b/backends/arm/test/ops/test_ones.py index f4dafca5e10..53351bfff53 100644 --- a/backends/arm/test/ops/test_ones.py +++ b/backends/arm/test/ops/test_ones.py @@ -65,7 +65,10 @@ def test_ones_tosa_INT(test_data: test_data_t): input_data(), OnesAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -79,7 +82,10 @@ def test_ones_u55_INT(test_data: test_data_t): OnesAdd.aten_op, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -92,8 +98,11 @@ def test_ones_u85_INT(test_data: test_data_t): input_data(), OnesAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") - pipeline.pop_stage("check.quant_nodes") + ) + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -133,5 +142,8 @@ def test_ones_vgf_INT(test_data: test_data_t): OnesAdd.aten_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index c9fe32bf86c..8938ebcc27e 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -38,6 +38,10 @@ "rank_4": lambda: (torch.rand(1, 5, 1, 10), [0, 2, 3, 1]), "rank_4_2": lambda: (torch.rand(1, 2, 5, 10), [1, 0, 2, 3]), "rank_4_3": lambda: (torch.rand(1, 10, 10, 5), [2, 0, 1, 3]), + "rank_4_large": lambda: (torch.rand(2, 8, 64, 65), [0, 2, 3, 1]), + "rank_3_large": lambda: (torch.rand(16, 64, 65), [1, 2, 0]), + "reshape_large_1": lambda: (torch.rand(1, 1, 65537), [0, 2, 1]), + "reshape_large_2": lambda: (torch.rand(65537, 1, 1), [1, 2, 0]), } diff --git a/backends/arm/test/ops/test_remainder.py b/backends/arm/test/ops/test_remainder.py index 2cba9532cde..336d4db1a7b 100644 --- a/backends/arm/test/ops/test_remainder.py +++ b/backends/arm/test/ops/test_remainder.py @@ -24,7 +24,12 @@ def _nonzero_float_tensor(*shape: int) -> torch.Tensor: class Remainder(torch.nn.Module): input_t = Tuple[torch.Tensor | float, torch.Tensor | float] - test_cases = { + aten_op_tensor = "torch.ops.aten.remainder.Tensor" + exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_remainder_Tensor" + aten_op_scalar = "torch.ops.aten.remainder.Scalar" + exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_remainder_Scalar" + + test_cases_tensor = { "rank2_tensors": lambda: ( torch.randn(2, 3) * 7, _nonzero_float_tensor(2, 3), @@ -37,44 +42,49 @@ class Remainder(torch.nn.Module): torch.randn(4, 5, 1), _nonzero_float_tensor(1, 5, 6), ), - "scalar_rhs": lambda: ( + } + + test_cases_scalar = { + "scalar_pos": lambda: ( torch.randn(1, 2, 3, 4), 0.25, ), + "scalar_neg": lambda: ( + torch.randn(3, 4), + -0.25, + ), } def forward(self, x: torch.Tensor | float, y: torch.Tensor | float) -> torch.Tensor: return torch.remainder(x, y) -def _get_aten_op(test_data: Remainder.input_t): - if any(isinstance(x, float) for x in test_data): - return "torch.ops.aten.remainder.Scalar" - else: - return "torch.ops.aten.remainder.Tensor" - - -def _get_exir_op(test_data: Remainder.input_t): - if isinstance(test_data[1], float): - return "executorch_exir_dialects_edge__ops_aten_remainder_Scalar" - else: - return "executorch_exir_dialects_edge__ops_aten_remainder_Tensor" +@common.parametrize("test_data", Remainder.test_cases_tensor) +def test_remainder_tensor_tosa_FP(test_data): + data = test_data() + pipeline = TosaPipelineFP[Remainder.input_t]( + Remainder(), + data, + Remainder.aten_op_tensor, + Remainder.exir_op_tensor, + ) + pipeline.run() -@common.parametrize("test_data", Remainder.test_cases) -def test_remainder_tosa_FP(test_data): +@common.parametrize("test_data", Remainder.test_cases_scalar) +def test_remainder_scalar_tosa_FP(test_data): data = test_data() pipeline = TosaPipelineFP[Remainder.input_t]( Remainder(), data, - _get_aten_op(data), - _get_exir_op(data), + Remainder.aten_op_scalar, + Remainder.exir_op_scalar, ) pipeline.run() -@common.parametrize("test_data", Remainder.test_cases) -def test_remainder_tosa_INT(test_data): +@common.parametrize("test_data", Remainder.test_cases_tensor) +def test_remainder_tensor_tosa_INT(test_data): pipeline = TosaPipelineINT[Remainder.input_t]( Remainder(), test_data(), @@ -83,9 +93,30 @@ def test_remainder_tosa_INT(test_data): pipeline.run() -@common.parametrize("test_data", Remainder.test_cases) +@common.parametrize("test_data", Remainder.test_cases_scalar) +def test_remainder_scalar_tosa_INT(test_data): + pipeline = TosaPipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +@common.XfailIfNoCorstone300 +def test_remainder_tensor_u55_INT(test_data): + pipeline = EthosU55PipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) @common.XfailIfNoCorstone300 -def test_remainder_u55_INT(test_data): +def test_remainder_scalar_u55_INT(test_data): pipeline = EthosU55PipelineINT[Remainder.input_t]( Remainder(), test_data(), @@ -94,9 +125,9 @@ def test_remainder_u55_INT(test_data): pipeline.run() -@common.parametrize("test_data", Remainder.test_cases) +@common.parametrize("test_data", Remainder.test_cases_tensor) @common.XfailIfNoCorstone320 -def test_remainder_u85_INT(test_data): +def test_remainder_tensor_u85_INT(test_data): pipeline = EthosU85PipelineINT[Remainder.input_t]( Remainder(), test_data(), @@ -105,23 +136,60 @@ def test_remainder_u85_INT(test_data): pipeline.run() -@common.parametrize("test_data", Remainder.test_cases) +@common.parametrize("test_data", Remainder.test_cases_scalar) +@common.XfailIfNoCorstone320 +def test_remainder_scalar_u85_INT(test_data): + pipeline = EthosU85PipelineINT[Remainder.input_t]( + Remainder(), + test_data(), + [], + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) @common.SkipIfNoModelConverter -def test_remainder_vgf_FP(test_data): +def test_remainder_tensor_vgf_FP(test_data): data = test_data() pipeline = VgfPipeline[Remainder.input_t]( Remainder(), data, - _get_aten_op(data), - _get_exir_op(data), + Remainder.aten_op_tensor, + Remainder.exir_op_tensor, tosa_version="TOSA-1.0+FP", ) pipeline.run() -@common.parametrize("test_data", Remainder.test_cases) +@common.parametrize("test_data", Remainder.test_cases_scalar) +@common.SkipIfNoModelConverter +def test_remainder_scalar_vgf_FP(test_data): + data = test_data() + pipeline = VgfPipeline[Remainder.input_t]( + Remainder(), + data, + Remainder.aten_op_scalar, + Remainder.exir_op_scalar, + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_tensor) +@common.SkipIfNoModelConverter +def test_remainder_tensor_vgf_INT(test_data): + pipeline = VgfPipeline[Remainder.input_t]( + Remainder(), + test_data(), + [], + tosa_version="TOSA-1.0+INT", + ) + pipeline.run() + + +@common.parametrize("test_data", Remainder.test_cases_scalar) @common.SkipIfNoModelConverter -def test_remainder_vgf_INT(test_data): +def test_remainder_scalar_vgf_INT(test_data): pipeline = VgfPipeline[Remainder.input_t]( Remainder(), test_data(), diff --git a/backends/arm/test/ops/test_rshift.py b/backends/arm/test/ops/test_rshift.py index f7a821e3a63..40258907b1e 100644 --- a/backends/arm/test/ops/test_rshift.py +++ b/backends/arm/test/ops/test_rshift.py @@ -91,7 +91,6 @@ def test_bitwise_right_shift_tensor_tosa_INT_scalar(test_data): RshiftScalar.torch_op_INT, RshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -104,7 +103,6 @@ def test_bitwise_right_shift_tensor_u55_INT_scalar(test_data): RshiftScalar.torch_op_INT, RshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") # Forced rounding in U55 HW causes off-by-one errors. pipeline.change_args("run_method_and_compare_outputs", inputs=test_data(), atol=1) @@ -120,7 +118,6 @@ def test_bitwise_right_shift_tensor_u85_INT_scalar(test_data): RshiftScalar.torch_op_INT, RshiftScalar.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -147,7 +144,6 @@ def test_bitwise_right_shift_tensor_vgf_INT_scalar(test_data): RshiftScalar.exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -174,7 +170,6 @@ def test_bitwise_right_shift_tensor_tosa_INT(test_data): RshiftTensor.torch_op, RshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -187,7 +182,6 @@ def test_bitwise_right_shift_tensor_u55_INT(test_data): RshiftTensor.torch_op, RshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") # Forced rounding in U55 HW causes off-by-one errors. pipeline.change_args("run_method_and_compare_outputs", inputs=test_data(), atol=1) @@ -203,7 +197,6 @@ def test_bitwise_right_shift_tensor_u85_INT(test_data): RshiftTensor.torch_op, RshiftTensor.exir_op, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -230,5 +223,4 @@ def test_bitwise_right_shift_tensor_vgf_INT(test_data): RshiftTensor.exir_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py index 23bb9dc1a4b..9e2f024dcdd 100644 --- a/backends/arm/test/ops/test_rsqrt.py +++ b/backends/arm/test/ops/test_rsqrt.py @@ -8,9 +8,14 @@ from typing import Tuple +import pytest import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -18,7 +23,8 @@ TosaPipelineINT, VgfPipeline, ) - +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.rsqrt.default" input_t1 = Tuple[torch.Tensor] # Input x @@ -104,3 +110,99 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor): tosa_version="TOSA-1.0+INT", ) pipeline.run() + + +def get_symmetric_a16w8_rsqrt_quantizer( + u55_config=False, per_channel_quantization=False +): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ), + ) + + +@common.parametrize("test_tensor", Rsqrt.test_parameters) +@pytest.mark.xfail( + reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." +) +def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor): + """Test rsqrt operation with int16 quantization""" + pipeline = TosaPipelineINT[input_t1]( + Rsqrt(), + test_tensor(), + aten_op, + exir_op=[], + per_channel_quantization=False, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False), + ) + # Run the pipeline + pipeline.run() + + +@common.parametrize("test_tensor", Rsqrt.test_parameters) +@common.XfailIfNoCorstone300 +@pytest.mark.xfail( + reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." +) +def test_rsqrt_16a8w_u55_INT16(test_tensor: torch.Tensor): + """Test rsqrt operation with int16 quantization on U55""" + pipeline = EthosU55PipelineINT[input_t1]( + Rsqrt(), + test_tensor(), + aten_op, + exir_ops=[], + per_channel_quantization=True, + use_to_edge_transform_and_lower=True, + atol=1e-03, + rtol=1e-03, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=True), + ) + pipeline.run() + + +@common.parametrize("test_tensor", Rsqrt.test_parameters) +@common.XfailIfNoCorstone320 +@pytest.mark.xfail( + reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." +) +def test_rsqrt_16a8w_u85_INT16(test_tensor: torch.Tensor): + """Test rsqrt operation with int16 quantization on U85""" + pipeline = EthosU85PipelineINT[input_t1]( + Rsqrt(), + test_tensor(), + aten_op, + exir_ops=[], + use_to_edge_transform_and_lower=True, + atol=1e-03, + rtol=1e-03, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False), + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_scalar_tensor.py b/backends/arm/test/ops/test_scalar_tensor.py index d5e5b365da1..356bcf508b7 100644 --- a/backends/arm/test/ops/test_scalar_tensor.py +++ b/backends/arm/test/ops/test_scalar_tensor.py @@ -73,7 +73,10 @@ def test_scalar_tensor_tosa_INT(test_data): tuple(data), ScalarTensor.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -112,7 +115,10 @@ def test_scalar_tensor_vgf_FP(test_data): pipeline.run() -@common.parametrize("test_data", int_test_data_suite) +@common.parametrize( + "test_data", + int_test_data_suite, +) @common.SkipIfNoModelConverter def test_scalar_tensor_vgf_INT(test_data): scalar, dtype, data = test_data() @@ -122,5 +128,8 @@ def test_scalar_tensor_vgf_INT(test_data): ScalarTensor.aten_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index c4f371a1a14..b3704c87fb6 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -435,5 +435,4 @@ def test_bitwise_right_shift_tensor_tosa_INT_inplace(): (torch.IntTensor(5),), aten_op="torch.ops.aten.bitwise_right_shift.Tensor", ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py index 1fdc4619131..b3b7fab5318 100644 --- a/backends/arm/test/ops/test_to_copy.py +++ b/backends/arm/test/ops/test_to_copy.py @@ -226,7 +226,6 @@ def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple): exir_op=[], ) pipeline.pop_stage("run_method_and_compare_outputs") - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/ops/test_upsample_bilinear2d.py b/backends/arm/test/ops/test_upsample_bilinear2d.py index 1edba708f1f..db440fcb3d4 100644 --- a/backends/arm/test/ops/test_upsample_bilinear2d.py +++ b/backends/arm/test/ops/test_upsample_bilinear2d.py @@ -7,7 +7,6 @@ import torch from executorch.backends.arm.test import common - from executorch.backends.arm.test.tester.test_pipeline import ( EthosU85PipelineINT, OpNotSupportedPipeline, @@ -196,6 +195,24 @@ def test_upsample_bilinear2d_vec_tosa_INT_Upsample( pipeline.run() +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_INT_a16w8( + test_data: torch.Tensor, +): + """Test upsample_bilinear2d vector op with int16 I/O quantization for TOSA INT.""" + test_data, size, scale_factor, compare_outputs = test_data + pipeline = TosaPipelineINT[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + tosa_extensions=["int16"], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_u55) @common.XfailIfNoCorstone300 def test_upsample_bilinear2d_vec_U55_INT_Upsample_not_delegated( @@ -305,6 +322,27 @@ def test_upsample_bilinear2d_vec_U85_INT_UpsamplingBilinear2d( pipeline.run() +@common.parametrize("test_data", test_data_suite_Uxx) +@common.XfailIfNoCorstone320 +def test_upsample_bilinear2d_vec_U85_INT_a16w8( + test_data: input_t1, +): + """Test upsample_bilinear2d vec op with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + data, size, scale_factor, compare_outputs = test_data + + pipeline = EthosU85PipelineINT[input_t1]( + UpsamplingBilinear2d(size, scale_factor), + (data,), + aten_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter def test_upsample_bilinear2d_vgf_FP_UpsamplingBilinear2d(test_data: torch.Tensor): diff --git a/backends/arm/test/ops/test_upsample_nearest2d.py b/backends/arm/test/ops/test_upsample_nearest2d.py index a39adefc168..e7da0643d0e 100644 --- a/backends/arm/test/ops/test_upsample_nearest2d.py +++ b/backends/arm/test/ops/test_upsample_nearest2d.py @@ -195,6 +195,22 @@ def test_upsample_nearest2d_vec_tosa_INT_interpolate(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_upsample_nearest2d_vec_tosa_INT_a16w8(test_data: torch.Tensor): + """Test upsample_nearest2d vector op with int16 I/O quantization for TOSA INT.""" + test_data, size, scale_factor, compare_outputs = test_data() + pipeline = TosaPipelineINT[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + tosa_extensions=["int16"], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter def test_upsample_nearest2d_vgf_FP(test_data: torch.Tensor): diff --git a/backends/arm/test/ops/test_while.py b/backends/arm/test/ops/test_while.py new file mode 100644 index 00000000000..f66d8995683 --- /dev/null +++ b/backends/arm/test/ops/test_while.py @@ -0,0 +1,187 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Tuple + +import torch +import torch.fx + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + +input_single = Tuple[torch.Tensor] +input_double = Tuple[torch.Tensor, torch.Tensor] + + +class WhileTwoInputsTwoOutputs(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, lhs: torch.Tensor, rhs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + def cond_fn(lhs_val: torch.Tensor, rhs_val: torch.Tensor) -> torch.Tensor: + total = torch.sum(rhs_val) + zero = torch.zeros_like(total) + return torch.gt(total, zero).squeeze() + + def body_fn( + lhs_val: torch.Tensor, rhs_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_lhs = torch.add(lhs_val, rhs_val) + next_rhs = torch.sub(rhs_val, torch.full((1,), 1.0)) + return (next_lhs, next_rhs) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (lhs, rhs), + (), + ) + return result # type: ignore + + +class WhileOneInputOneBufferTwoOutputs(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("threshold", torch.tensor((30.0,))) + + def forward(self, value: torch.Tensor) -> torch.Tensor: + def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor: + total = value.sum() + return torch.lt(total, limit).squeeze() + + def body_fn( + value: torch.Tensor, limit: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return (torch.add(value, value), limit.clone()) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (value, self.threshold), + (), + ) + return result # type: ignore + + +class WhileAdditionalArg(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("threshold", torch.tensor((30.0,))) + + def forward(self, value: torch.Tensor) -> torch.Tensor: + def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor: + total = value.sum() + return torch.lt(total, limit).squeeze() + + def body_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor: + return torch.add(value, value) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (value,), + (self.threshold,), + ) + return result # type: ignore + + +class WhileSingleCapturedOutput(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("threshold", torch.tensor((30.0,))) + + def forward(self, value: torch.Tensor) -> torch.Tensor: + def cond_fn(value: torch.Tensor, limit: torch.Tensor) -> torch.Tensor: + total = value.sum() + return torch.lt(total, limit).squeeze() + + def body_fn( + value: torch.Tensor, limit: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return (torch.add(value, value), limit.clone()) + + result = torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + (value, self.threshold), + (), + ) + return result[0] # type: ignore + + +def _single_input_case( + module_factory: Callable[[], torch.nn.Module], +) -> Callable[[], Tuple[torch.nn.Module, input_single]]: + def _create() -> Tuple[torch.nn.Module, input_single]: + return module_factory(), (torch.ones(2, 3),) + + return _create + + +def _dual_input_case( + module_factory: Callable[[], torch.nn.Module], +) -> Callable[[], Tuple[torch.nn.Module, input_double]]: + def _create() -> Tuple[torch.nn.Module, input_double]: + return module_factory(), (torch.zeros(2, 3), torch.full((2, 3), -2.0)) + + return _create + + +test_cases: dict[str, Callable[[], Tuple[torch.nn.Module, Tuple]]] = { + "two_in_two_out": _dual_input_case(WhileTwoInputsTwoOutputs), + "one_in_one_buffer_two_out": _single_input_case(WhileOneInputOneBufferTwoOutputs), + "additional_arg": _single_input_case(WhileAdditionalArg), + "two_in_one_captured_out": _single_input_case(WhileSingleCapturedOutput), +} + + +@common.parametrize( + "case", + test_cases, + xfails={ + "additional_arg": "Support not implemented.", + "two_in_one_captured_out": "When only one output is used, the second one is removed, which is not allowed in TOSA.", + }, +) +def test_while_loop_tosa_FP(case: Callable[[], Tuple[torch.nn.Module, Tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineFP[tuple]( + module, + example_inputs, + "torch.ops.higher_order.while_loop", + tosa_extensions=["cf"], + ) + pipeline.run() + + +@common.parametrize( + "case", + test_cases, + xfails={ + "additional_arg": "Support not implemented.", + "two_in_one_captured_out": "When only one output is used, the second one is removed, which is not allowed in TOSA.", + }, +) +def test_while_loop_tosa_INT(case: Callable[[], Tuple[torch.nn.Module, Tuple]]): + module, example_inputs = case() + pipeline = TosaPipelineINT[tuple]( + module, + example_inputs, + "torch.ops.higher_order.while_loop", + tosa_extensions=["cf"], + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + ArmTester.check_not, + pipeline.tester, + ["torch.ops.higher_order.while_loop"], + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_zeros.py b/backends/arm/test/ops/test_zeros.py index caee678282a..d9a885620d9 100644 --- a/backends/arm/test/ops/test_zeros.py +++ b/backends/arm/test/ops/test_zeros.py @@ -65,7 +65,10 @@ def test_zeros_tosa_INT(test_data: test_data_t): input_data(), ZerosAdd.aten_op, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -79,7 +82,10 @@ def test_zeros_u55_INT(test_data: test_data_t): ZerosAdd.aten_op, use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -92,8 +98,11 @@ def test_zeros_u85_INT(test_data: test_data_t): input_data(), ZerosAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") - pipeline.pop_stage("check.quant_nodes") + ) + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -139,5 +148,8 @@ def test_zeros_vgf_INT(test_data: test_data_t): ZerosAdd.aten_op, tosa_version="TOSA-1.0+INT", ) - pipeline.pop_stage("check.quant_nodes") + # Pop the quantization check stage if it exists as no + # quantization nodes will be present for int + fp inputs. + if pipeline.has_stage("check.quant_nodes"): + pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py index 7c7ad984e4c..5366e5453c1 100644 --- a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py +++ b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py @@ -90,7 +90,6 @@ def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT( aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -152,7 +151,6 @@ def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT( aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -214,7 +212,6 @@ def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT( aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -387,7 +384,6 @@ def test_convert_full_int64_dtype_to_int32_pass_tosa_INT( aten_ops_checks, exir_ops_checks, ) - pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py b/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py new file mode 100644 index 00000000000..eb395403e3f --- /dev/null +++ b/backends/arm/test/passes/test_convert_permute_singleton_to_view_pass.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm._passes import ConvertPermuteSingletonToViewPass +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor] + + +class PermuteSingletonAxesModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(0, 2, 3, 1) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 1, 3, 4),) + + +def test_convert_permute_singleton_to_view_applies(): + module = PermuteSingletonAxesModule() + pipeline = PassPipeline[input_t]( + module, + PermuteSingletonAxesModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() + + +class PermuteNonSingletonModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(0, 2, 1) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 3, 4),) + + +def test_convert_permute_singleton_to_view_skips_non_singleton(): + module = PermuteNonSingletonModule() + pipeline = PassPipeline[input_t]( + module, + PermuteNonSingletonModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() + + +class PermuteSameSizedNonSingletonModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(2, 1, 0) + + @staticmethod + def input() -> input_t: + return (torch.randn(2, 1, 2),) + + +def test_convert_permute_singleton_to_view_skips_same_sized_non_singleton(): + module = PermuteSameSizedNonSingletonModule() + pipeline = PassPipeline[input_t]( + module, + PermuteSameSizedNonSingletonModule.input(), + quantize=False, + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ], + pass_list=[ConvertPermuteSingletonToViewPass], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_convert_to_clamp.py b/backends/arm/test/passes/test_convert_to_clamp.py index 5072af000b0..b54c177e52f 100644 --- a/backends/arm/test/passes/test_convert_to_clamp.py +++ b/backends/arm/test/passes/test_convert_to_clamp.py @@ -7,7 +7,7 @@ from typing import ClassVar, Dict, Tuple import torch -from executorch.backends.arm._passes.convert_to_clamp import ConvertToClampPass +from executorch.backends.arm._passes.convert_to_clamp_pass import ConvertToClampPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline diff --git a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py index 405c3d7ca8f..c4aebae2292 100644 --- a/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py +++ b/backends/arm/test/passes/test_decompose_avg_pool2d_pass.py @@ -6,7 +6,9 @@ from typing import cast, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d +from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( + DecomposeAvgPool2dPass, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline @@ -75,6 +77,6 @@ def test_decompose_avg_pool2d_tosa_MI(module: ModuleWithInputs) -> None: # After decomposition, we should still see avg_pool2d (transformed) "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1, }, - pass_list=[DecomposeAvgPool2d], + pass_list=[DecomposeAvgPool2dPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_convert_int_pow_to_muls.py b/backends/arm/test/passes/test_decompose_int_pow_pass.py similarity index 92% rename from backends/arm/test/passes/test_convert_int_pow_to_muls.py rename to backends/arm/test/passes/test_decompose_int_pow_pass.py index bccde782f55..a9a74c633e1 100644 --- a/backends/arm/test/passes/test_convert_int_pow_to_muls.py +++ b/backends/arm/test/passes/test_decompose_int_pow_pass.py @@ -6,7 +6,7 @@ from typing import cast, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes import ConvertIntPowToMuls +from executorch.backends.arm._passes import DecomposeIntPowPass from executorch.backends.arm.test import common @@ -60,7 +60,7 @@ def get_inputs(self) -> input_t: @common.parametrize("data", test_data) -def test_convert_pow_to_muls(data: TestParam) -> None: +def test_decompose_int_pow(data: TestParam) -> None: module_with_inputs, nbr_muls = data module = cast(torch.nn.Module, module_with_inputs) pipeline = PassPipeline[input_t]( @@ -75,6 +75,6 @@ def test_convert_pow_to_muls(data: TestParam) -> None: "executorch_exir_dialects_edge__ops_aten_mul_Tensor": nbr_muls, }, ops_not_after_pass=["executorch_exir_dialects_edge__ops_pow_Tensor_Scalar"], - pass_list=[ConvertIntPowToMuls], + pass_list=[DecomposeIntPowPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py index bd83bfc9a22..b926e15b92a 100644 --- a/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py +++ b/backends/arm/test/passes/test_decompose_linalg_vector_norm_pass.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm._passes.decompose_linalg_vector_norm_pass import ( - DecomposeLinearVectorNormPass, + DecomposeLinalgVectorNormPass, ) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline @@ -65,7 +65,7 @@ def get_inputs(self) -> input_t: @common.parametrize("module", modules) def test_decompose_vector_norm_tosa_INT(module: ModuleWithInputs) -> None: """ - This test creates a PassPipeline that applies the DecomposeLinearVectorNormPass. + This test creates a PassPipeline that applies the DecomposeLinalgVectorNormPass. The expected primitive ops vary depending on the norm order: - p == 1: should decompose to ABS and SUM. - p == 2 (default): should decompose to MUL, SUM, and SQRT. @@ -102,6 +102,6 @@ def test_decompose_vector_norm_tosa_INT(module: ModuleWithInputs) -> None: ops_not_after_pass=[ "executorch_exir_dialects_edge__ops_aten_linarg_vector_norm_default", ], - pass_list=[DecomposeLinearVectorNormPass], + pass_list=[DecomposeLinalgVectorNormPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py index 08bf960da7d..eb073265a63 100644 --- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py +++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py @@ -6,7 +6,7 @@ from typing import cast, ClassVar, Dict, Protocol, Tuple import torch -from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass +from executorch.backends.arm._passes.fuse_batch_norm2d_pass import FuseBatchNorm2dPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline @@ -159,6 +159,6 @@ def test_fuse_batchnorm_tosa_FP(module: ModuleWithBatchNormAttrs) -> None: quantize=False, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[FuseBatchnorm2DPass], + passes_with_exported_program=[FuseBatchNorm2dPass], ) pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 95492075c0d..deb017bf662 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -8,7 +8,7 @@ import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ( - ComputeConstantOpsAOT, + ComputeConstantOpsAOTPass, FuseConstantArgsPass, ) from executorch.backends.arm.test import common @@ -157,7 +157,10 @@ def test_fuse_const_ops_tosa_FP(module: ModuleWithFuseAttrs) -> None: ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, ops_not_after_pass=module.ops_not_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() @@ -170,7 +173,10 @@ def test_fuse_const_ops_tosa_INT(module: ModuleWithFuseAttrs) -> None: quantize=True, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() @@ -183,7 +189,10 @@ def test_fuse_const_ops_tosa_BI_cat(module: ModuleWithFuseAttrs) -> None: quantize=True, ops_before_pass=module.ops_before_pass, ops_after_pass=module.ops_after_pass, - passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass], + passes_with_exported_program=[ + ComputeConstantOpsAOTPass, + FuseConstantArgsPass, + ], ) pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py index a7e80794015..ffe56e72691 100644 --- a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py +++ b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Dict, Tuple import torch from executorch.backends.arm._passes import FuseDuplicateUsersPass @@ -13,7 +13,12 @@ input_t = Tuple[torch.Tensor] # Input x -class FuseaAvgPool(torch.nn.Module): +class ModuleWithOps(torch.nn.Module): + ops_before_pass: Dict[str, int] + ops_after_pass: Dict[str, int] + + +class FuseaAvgPool(ModuleWithOps): ops_before_pass = { "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 3, } @@ -27,7 +32,7 @@ def forward(self, x): return self.avg(x) + self.avg(x) + self.avg(x) -class FuseAvgPoolChain(torch.nn.Module): +class FuseAvgPoolChain(ModuleWithOps): ops_before_pass = { "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6, } @@ -44,14 +49,14 @@ def forward(self, x): return first + second + third -modules = { +modules: Dict[str, ModuleWithOps] = { "fuse_avg_pool": FuseaAvgPool(), "fuse_avg_pool_chain": FuseAvgPoolChain(), } @common.parametrize("module", modules) -def test_fuse_duplicate_ops_FP(module: torch.nn.Module): +def test_fuse_duplicate_ops_FP(module: ModuleWithOps): pipeline = PassPipeline[input_t]( module=module, test_data=(torch.ones(1, 1, 1, 1),), diff --git a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py index 7c32cee8534..2461a0e833a 100644 --- a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py +++ b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py @@ -8,9 +8,13 @@ import torch from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass -from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.test.tester.test_pipeline import ( + PassPipeline, + TosaPipelineINT, +) input_t = Tuple[torch.Tensor, torch.Tensor] # weights, indices +input_t3 = Tuple[torch.Tensor, torch.LongTensor, torch.Tensor] class Int64InputModel(torch.nn.Module): @@ -44,3 +48,67 @@ def test_int64_model_tosa_FP(): ) pipeline.pop_stage(-1) # Do not compare output pipeline.run() + + +class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module): + aten_op = "torch.ops.aten.index_copy_.default" + + def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): + return x.index_copy_(0, index, y) + + def get_inputs(self) -> input_t3: + return ( + torch.zeros(5, 3), + torch.LongTensor([0, 4, 2]), + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float), + ) + + +def test_upcast_to_int64_for_index_copy_inplace_tosa_INT(): + module = UpcastToInt64ForIndexCopyInplaceModel() + pipeline = TosaPipelineINT[input_t3]( + module, + module.get_inputs(), + aten_op=module.aten_op, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 0, + }, + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +class UpcastToInt64ForIndexCopyModel(torch.nn.Module): + aten_op = "torch.ops.aten.index_copy.default" + + def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): + return x.index_copy(0, index, y) + + def get_inputs(self) -> input_t3: + return ( + torch.zeros(5, 3), + torch.LongTensor([0, 4, 2]), + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float), + ) + + +def test_upcast_to_int64_for_index_copy_tosa_INT(): + module = UpcastToInt64ForIndexCopyModel() + pipeline = TosaPipelineINT[input_t3]( + module, + module.get_inputs(), + aten_op=module.aten_op, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 0, + }, + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() diff --git a/backends/arm/test/quantizer/test_set_module_name.py b/backends/arm/test/quantizer/test_set_module_name.py new file mode 100644 index 00000000000..56131a83e86 --- /dev/null +++ b/backends/arm/test/quantizer/test_set_module_name.py @@ -0,0 +1,158 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_a16w8_quantization_config, + get_symmetric_quantization_config, + is_annotated, + QuantizationConfig, + TOSAQuantizer, +) +from executorch.backends.arm.quantizer.quantization_config import QuantizationSpec +from executorch.backends.arm.tosa import TosaSpecification +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +DQ_PER_CHANNEL = torch.ops.quantized_decomposed.dequantize_per_channel.default +DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default +Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default + + +class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv0 = torch.nn.Conv2d( + 3, + 16, + kernel_size=4, + ) + self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=3, bias=False) + self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3) + + def forward(self, x): + x = self.conv0(x) + x = torch.sigmoid(x) + x = self.conv1(x) + x = torch.tanh(x) + x = self.conv2(x) + return x + + +test_inputs = (torch.randn(1, 3, 64, 64),) + + +def validate_per_tensor_quant(node: torch.fx.Node, qspec: QuantizationSpec): + _, _, zero_point, qmin, qmax, dtype = node.args + if qspec.qscheme == torch.per_tensor_symmetric: + assert ( + zero_point == 0 + ), f"Zero point {zero_point} is not zero for symmetric quantization" + assert ( + qmin == qspec.quant_min + ), f"Quant min {qmin} does not match expected {qspec.quant_min}" + assert ( + qmax == qspec.quant_max + ), f"Quant max {qmax} does not match expected {qspec.quant_max}" + assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}" + + +def validate_per_channel_quant(node: torch.fx.Node, qspec: QuantizationSpec): + _, _, _, channel_axis, qmin, qmax, dtype = node.args + assert ( + channel_axis == qspec.ch_axis + ), f"Channel axis {channel_axis} does not match expected {qspec.ch_axis}" + assert ( + qmin == qspec.quant_min + ), f"Quant min {qmin} does not match expected {qspec.quant_min}" + assert ( + qmax == qspec.quant_max + ), f"Quant max {qmax} does not match expected {qspec.quant_max}" + assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}" + + +def validate_input(input_node: torch.fx.Node, qspec: QuantizationSpec | None): + if qspec is None: + return + + per_channel = qspec.qscheme == torch.per_channel_symmetric + expected_dequant_op = DQ_PER_CHANNEL if per_channel else DQ_PER_TENSOR + assert ( + input_node.target == expected_dequant_op + ), f"Input node {input_node} is not quantized as expected" + if per_channel: + validate_per_channel_quant(input_node, qspec) + else: + validate_per_tensor_quant(input_node, qspec) + + +def validate_output(node: torch.fx.Node, qspec: QuantizationSpec | None): + if qspec is None: + return + users = list(node.users) + assert len(users) == 1, f"Node {node} should have exactly one user" + assert ( + users[0].target == Q_PER_TENSOR + ), f"Output node {users[0]} is not quantized as expected" + validate_per_tensor_quant(users[0], qspec) + + +def validate_node( + node: torch.fx.Node, quantization_config: QuantizationConfig | None +) -> None: + if quantization_config is None: + assert not is_annotated(node), f"Node {node} is unexpectedly annotated" + return + + assert is_annotated(node), f"Node {node} is not annotated" + input_qspec = quantization_config.get_input_act_qspec() + output_qspec = quantization_config.get_output_act_qspec() + weight_qspec = quantization_config.get_weight_qspec() + + if len(node.all_input_nodes) == 3: + input_node, weight_node, bias_node = node.all_input_nodes + bias_qspec = quantization_config.get_bias_qspec(node) + validate_input(bias_node, bias_qspec) + else: + input_node, weight_node = node.all_input_nodes + + validate_input(input_node, input_qspec) + validate_input(weight_node, weight_qspec) + validate_output(node, output_qspec) + + +def test_set_module_name() -> None: + model = ConvModel() + model.eval() + + # Set up quantizer with different configs for different modules + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + quantizer = TOSAQuantizer(tosa_spec) + int8_config = get_symmetric_quantization_config(is_per_channel=False) + a16w8_config = get_symmetric_a16w8_quantization_config() + # Set module-specific configurations but don't set global config to test that + # only specified modules are quantized + quantizer.set_module_name("conv0", int8_config) + quantizer.set_module_name("conv1", a16w8_config) + + # Export model + exported_model = torch.export.export(model, test_inputs) + + # Prepare, calibrate and convert model + prepared_model = prepare_pt2e(exported_model.module(), quantizer) + prepared_model(*test_inputs) + converted_model = convert_pt2e(prepared_model) + + validate_node( + [node for node in converted_model.graph.nodes if node.name == "conv2d"][0], + int8_config, + ) + validate_node( + [node for node in converted_model.graph.nodes if node.name == "conv2d_1"][0], + a16w8_config, + ) + validate_node( + [node for node in converted_model.graph.nodes if node.name == "conv2d_2"][0], + None, + ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index ae1fc136ce7..d7112bfa654 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -31,12 +31,13 @@ from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.backends.arm.vgf.model_converter import find_model_converter_binary from executorch.exir import ExecutorchProgramManager, ExportedProgram from executorch.exir.lowered_backend_module import LoweredBackendModule from torch.fx.node import Node from torch.overrides import TorchFunctionMode -from tosa.TosaGraph import TosaGraph # type: ignore[import-untyped] +from tosa.TosaGraph import TosaGraph # type: ignore[import-not-found, import-untyped] logger = logging.getLogger(__name__) @@ -678,11 +679,15 @@ def corstone320_installed() -> bool: def model_converter_installed() -> bool: - cmd = ["model-converter", "--version"] + model_converter = find_model_converter_binary() + if model_converter is None: + return False + try: - _run_cmd(cmd, check=True) - except: + _run_cmd([model_converter, "--version"], check=True) + except Exception: return False + return True @@ -714,7 +719,9 @@ def assert_elf_path_exists(elf_path): ) -def get_elf_path(target_board: str, use_portable_ops: bool = False): +def get_elf_path(target_board: str, use_portable_ops: bool = False) -> str: + elf_path = "" + if target_board not in VALID_TARGET: raise ValueError(f"Unsupported target: {target_board}") @@ -729,14 +736,13 @@ def get_elf_path(target_board: str, use_portable_ops: bool = False): f"arm_semihosting_executor_runner_{portable_ops_str}{target_board}", "arm_executor_runner", ) - assert_elf_path_exists(elf_path) elif target_board == "vkml_emulation_layer": elf_path = os.path.join( f"arm_test/arm_executor_runner_{portable_ops_str}vkml", "executor_runner", ) - assert_elf_path_exists(elf_path) + assert_elf_path_exists(elf_path) return elf_path @@ -761,13 +767,13 @@ def run_tosa_graph( inputs_np = [torch_tensor_to_numpy(input_tensor) for input_tensor in inputs] if isinstance(tosa_version, Tosa_1_00): - import tosa_reference_model as reference_model # type: ignore[import-untyped] + import tosa_reference_model as reference_model # type: ignore[import-not-found, import-untyped] - debug_mode = "ALL" if logger.level <= logging.DEBUG else None + debug_mode = "ALL" if logger.getEffectiveLevel() <= logging.DEBUG else None outputs_np, status = reference_model.run( graph, inputs_np, - verbosity=_tosa_refmodel_loglevel(logger.level), + verbosity=_tosa_refmodel_loglevel(logger.getEffectiveLevel()), initialize_variable_tensor_from_numpy=True, debug_mode=debug_mode, ) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 01cafad13d0..ffb18043536 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -22,6 +22,7 @@ def define_arm_tests(): "ops/test_linear.py", "ops/test_mul.py", "ops/test_permute.py", + "ops/test_rsqrt.py", "ops/test_slice.py", "ops/test_sigmoid.py", "ops/test_sub.py", diff --git a/backends/arm/test/test_arm_baremetal.sh b/backends/arm/test/test_arm_baremetal.sh index 4bc4fe0f06d..8c45e28376b 100755 --- a/backends/arm/test/test_arm_baremetal.sh +++ b/backends/arm/test/test_arm_baremetal.sh @@ -157,7 +157,6 @@ test_pytest_ethosu_fvp() { # Same as test_pytest but also sometime verify using test_pytest_ops_vkml() { # Same as test_pytest but also sometime verify using VKML runtime echo "${TEST_SUITE_NAME}: Run pytest operator tests with VKML runtime" - backends/arm/scripts/build_executorch.sh backends/arm/test/setup_testing_vkml.sh pytest --verbose --color=yes --numprocesses=auto --durations=10 backends/arm/test/ \ @@ -190,11 +189,11 @@ test_run_vkml() { # End to End model tests using run.sh echo "${TEST_SUITE_NAME}: Test VKML" out_folder="arm_test/test_run" - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=add --output=${out_folder}/runner - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=mul --output=${out_folder}/runner + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=add --output=${out_folder}/runner --bundleio + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=mul --output=${out_folder}/runner --bundleio - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qadd --output=${out_folder}/runner - examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qops --output=${out_folder}/runner + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qadd --output=${out_folder}/runner --bundleio + examples/arm/run.sh --et_build_root=${out_folder} --target=vgf --model_name=qops --output=${out_folder}/runner --bundleio echo "${TEST_SUITE_NAME}: PASS" } @@ -254,8 +253,8 @@ test_models_vkml() { # End to End model tests using model_test.py # VKML echo "${TEST_SUITE_NAME}: Test target VKML" - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=mv2 - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --no_quantize --model=mv2 + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet18 --extra_runtime_flags="--bundleio_atol=0.2 --bundleio_rtol=0.2" + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=vgf --model=resnet50 --extra_runtime_flags="--bundleio_atol=0.2 --bundleio_rtol=0.2" echo "${TEST_SUITE_NAME}: PASS" } @@ -390,9 +389,19 @@ test_memory_allocation() { --require "model_pte_program_size" "<= 3000 B" \ --require "method_allocator_planned" "<= 64 B" \ --require "method_allocator_loaded" "<= 1024 B" \ - --require "method_allocator_input" "<= 4 B" \ + --require "method_allocator_input" "<= 16 B" \ --require "Total DRAM used" "<= 0.06 KiB" echo "${TEST_SUITE_NAME}: PASS" } +test_undefinedbehavior_sanitizer() { + echo "${TEST_SUITE_NAME}: Test ethos-u executor_runner with UBSAN" + + mkdir -p arm_test/test_run + # Ethos-U85 + echo "${TEST_SUITE_NAME}: Test target Ethos-U85" + examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=examples/arm/example_modules/add.py --build_type=UndefinedSanitizer + echo "${TEST_SUITE_NAME}: PASS" +} + ${TEST_SUITE} diff --git a/backends/arm/test/test_model.py b/backends/arm/test/test_model.py index 5dc11e12a08..04972856044 100755 --- a/backends/arm/test/test_model.py +++ b/backends/arm/test/test_model.py @@ -67,9 +67,15 @@ def get_args(): parser.add_argument( "--extra_flags", required=False, - default=None, + default="", help="Extra cmake flags to pass the when building the executor_runner", ) + parser.add_argument( + "--extra_runtime_flags", + required=False, + default="", + help="Extra runtime flags to pass the final runner/executable", + ) parser.add_argument( "--timeout", required=False, @@ -130,20 +136,18 @@ def build_pte( no_intermediate: bool, no_quantize: bool, ): - pte_file_ending = "pte" command_list = [ "python3", "-m", "examples.arm.aot_arm_compiler", "--delegate", + "--bundleio", f"--model_name={model_name}", f"--target={target}", f"--output={build_output}", ] if "vgf" != target: - pte_file_ending = "bpte" - command_list.append("--bundleio") command_list.append(f"--system_config={system_config}") command_list.append(f"--memory_mode={memory_mode}") @@ -155,6 +159,7 @@ def build_pte( run_external_cmd(command_list) + pte_file_ending = "bpte" pte_file = os.path.join( output, f"{model_name}_arm_delegate_{args.target}.{pte_file_ending}" ) @@ -218,6 +223,7 @@ def build_vkml_runtime( os.path.join(script_path, "build_executor_runner_vkml.sh"), f"--et_build_root={et_build_root}", "--etdump", + "--bundleio", "--build_type=Release", f"--extra_build_flags=-DET_DUMP_OUTPUT=OFF {extra_flags}", f"--output={build_path}", @@ -228,13 +234,14 @@ def build_vkml_runtime( return runner -def run_vkml(script_path: str, pte_file: str, runner_build_path: str): +def run_vkml(script_path: str, pte_file: str, runner_build_path: str, extra_flags: str): run_external_cmd( [ "bash", os.path.join(script_path, "run_vkml.sh"), f"--model={pte_file}", f"--build_path={runner_build_path}", + f"--optional_flags={extra_flags}", ] ) @@ -297,7 +304,7 @@ def run_vkml(script_path: str, pte_file: str, runner_build_path: str): ) start_time = time.perf_counter() - run_vkml(script_path, pte_file, build_path) + run_vkml(script_path, pte_file, build_path, args.extra_runtime_flags) end_time = time.perf_counter() print( f"[Test model: {end_time - start_time:.2f} s] Tested VKML runner: {vkml_runner}" diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index 9bea6337655..3bcac603a9e 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -312,11 +312,8 @@ def dump_error_output( if __name__ == "__main__": - import sys - logging.basicConfig(stream=sys.stdout, level=logging.INFO) - - """ This is expected to produce the example output of print_diff""" + """This is expected to produce the example output of print_diff""" torch.manual_seed(0) a = torch.rand(3, 3, 2, 2) * 0.01 b = a.clone().detach() diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 1d9ee42c19e..d617a424b33 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -48,7 +48,9 @@ dump_error_output, print_error_diffs, ) +from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize from executorch.backends.arm.test.tester.serialize import Serialize + from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.mapping import extract_tensor_meta @@ -313,7 +315,7 @@ def quantize( # Same stage type as parent but exposed via module alias if quantize_stage is None: quantizer = create_quantizer(self.compile_spec) - quantize_stage = tester.Quantize( + quantize_stage = Quantize( quantizer, get_symmetric_quantization_config(), ) @@ -832,7 +834,7 @@ def _dump_str(to_print: str, path_to_dump: Optional[str] = None): with open(path_to_dump, "a") as fp: fp.write(to_print) else: - logger.info(to_print) + print(to_print) def _format_dict(to_print: dict, print_table: bool = True) -> str: diff --git a/backends/arm/test/tester/quantize.py b/backends/arm/test/tester/quantize.py new file mode 100644 index 00000000000..18ecd401efe --- /dev/null +++ b/backends/arm/test/tester/quantize.py @@ -0,0 +1,43 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from executorch.backends.arm.quantizer import TOSAQuantizer +from executorch.backends.test.harness.stages.quantize import Quantize + +from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( + DuplicateDynamicQuantChainPass, +) + +from torch.export import export + + +class ArmQuantize(Quantize): + + def run( + self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] + ) -> None: + assert inputs is not None + if self.is_qat: + artifact.train() + captured_graph = export(artifact, inputs, strict=True).module() + + if not isinstance(self.quantizer, TOSAQuantizer): + raise ValueError("ArmQuantizer can only run with TOSAQuantizer.") + + if self.calibration_samples is not None: + converted = self.quantizer.quantize_with_submodules( + captured_graph, self.calibration_samples, bool(self.is_qat) # type: ignore + ) + else: + converted = self.quantizer.quantize_with_submodules( + captured_graph, [inputs], bool(self.is_qat) + ) + + DuplicateDynamicQuantChainPass()(converted) + + self.converted_graph = converted diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index f3f5ab390e5..824f13417b2 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -25,18 +25,19 @@ from executorch.backends.arm.quantizer import ( EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, TOSAQuantizer, VgfQuantizer, ) from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses + +from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, ) - -from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.pass_base import ExportPass from torch._export.pass_base import PassType @@ -52,6 +53,13 @@ def _require_tosa_version() -> str: return version +def _has_quantizable_inputs(test_data: T) -> bool: + for data in test_data: + if isinstance(data, torch.Tensor) and data.is_floating_point(): + return True + return False + + class PipelineStage: """Container for a pipeline stage (callable plus arguments).""" @@ -366,9 +374,15 @@ def __init__( ) quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose 16A8W quantization config when int16 extension is requested + if "int16" in tosa_extensions: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -384,30 +398,32 @@ def __init__( ) self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) if run_on_tosa_ref_model: self.add_stage( @@ -527,6 +543,7 @@ def __init__( run_on_fvp: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, + a16w8_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str | None = None, tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, @@ -539,9 +556,15 @@ def __init__( tosa_debug_mode=tosa_debug_mode, ) quantizer = EthosUQuantizer(compile_spec) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose int8 or int16 activation quantization + if a16w8_quantization: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -557,30 +580,32 @@ def __init__( self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) if run_on_fvp: self.add_stage(self.tester.serialize) @@ -618,6 +643,7 @@ def __init__( run_on_fvp: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, + a16w8_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str | None = None, tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, @@ -630,9 +656,15 @@ def __init__( tosa_debug_mode=tosa_debug_mode, ) quantizer = EthosUQuantizer(compile_spec) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose int8 or int16 activation quantization + if a16w8_quantization: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -648,30 +680,32 @@ def __init__( self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) if run_on_fvp: self.add_stage(self.tester.serialize) @@ -772,7 +806,10 @@ def __init__( self.add_stage(self.tester.check_count, ops_after_pass, suffix="after") if ops_not_after_pass: self.add_stage(self.tester.check_not, ops_not_after_pass, suffix="after") - self.add_stage(self.tester.run_method_and_compare_outputs) + self.add_stage( + self.tester.run_method_and_compare_outputs, + inputs=self.test_data, + ) def run(self): with TosaLoweringContext(self.tosa_spec): @@ -978,30 +1015,32 @@ def __init__( self.add_stage(self.tester.quantize, quant_stage, pos=0) - self.add_stage_after( - "quantize", - self.tester.check, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) - remove_quant_nodes_stage = ( "to_edge_transform_and_lower" if use_to_edge_transform_and_lower else "partition" ) - self.add_stage_after( - remove_quant_nodes_stage, - self.tester.check_not, - [ - "torch.ops.quantized_decomposed.dequantize_per_tensor.default", - "torch.ops.quantized_decomposed.quantize_per_tensor.default", - ], - suffix="quant_nodes", - ) + + if _has_quantizable_inputs(test_data): + # only add stages if we have quantizable input + self.add_stage_after( + "quantize", + self.tester.check, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) + self.add_stage_after( + remove_quant_nodes_stage, + self.tester.check_not, + [ + "torch.ops.quantized_decomposed.dequantize_per_tensor.default", + "torch.ops.quantized_decomposed.quantize_per_tensor.default", + ], + suffix="quant_nodes", + ) else: self.add_stage_after( "export", diff --git a/backends/arm/tosa/TARGETS b/backends/arm/tosa/TARGETS index 51919025591..d0f7a743f53 100644 --- a/backends/arm/tosa/TARGETS +++ b/backends/arm/tosa/TARGETS @@ -11,20 +11,6 @@ runtime.python_library( ":specification", ], ) -runtime.python_library( - name = "quant_utils", - srcs = [ - "quant_utils.py", - ], - deps = [ - "fbsource//third-party/pypi/numpy:numpy", - "fbsource//third-party/tosa_tools:serializer", - "fbsource//third-party/tosa_tools:tosa", - "//executorch/backends/arm:constants", - ":mapping", - "//executorch/exir/dialects:lib", - ], -) runtime.python_library( name = "specification", srcs = [ @@ -41,7 +27,6 @@ runtime.python_library( "utils.py", ], deps = [ - ":quant_utils", "//executorch/backends/arm/operators:node_visitor", ], ) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 1631526e360..99fcadac081 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -2,13 +2,19 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide TOSA backend entry points for the Arm ExecuTorch integration. +Implement the Ahead-of-Time (AoT) preprocessing path that lowers an +``ExportedProgram`` to a TOSA flatbuffer using Arm's lowering pipeline. Use +this module either as a standalone backend that produces a TOSA artifact or as +part of a composed pipeline for hardware backends that consume TOSA as an +intermediate form. + +Use ``TOSABackend.preprocess`` to return the serialized TOSA flatbuffer that +subsequent stages (for example, JIT or hardware-specific compilers) consume. + +""" -# -# Main implementation of AoT flow to partition and preprocess for Arm target -# backends. Converts via TOSA as an intermediate form supported by AoT and -# JIT compiler flows. -# import logging import tempfile from collections import deque @@ -28,7 +34,7 @@ from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import Graph, GraphModule, Node @@ -37,14 +43,19 @@ def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]: - """ - Returns dictionary: node name -> external ids + """Assign deterministic output IDs to nodes reachable from graph outputs. + + Args: + ep_graph (Graph): FX graph produced by export preprocessing. + + Returns: + dict[str, int]: Mapping from node name to external output index. - Assign id to an output node of the model so we can trace it. """ node2external_id = {} def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): + """Walk producer graph from ``start_nodes`` and record external IDs.""" q = deque(start_nodes) while q: n = q.popleft() @@ -65,7 +76,19 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]): + """Reorder graph outputs to match ascending external IDs. + + Args: + graph_module (GraphModule): Graph to reorder in place. + node_to_id_map (dict[str, int]): Mapping from node name to output index. + + Returns: + GraphModule: Updated graph module with deterministic output ordering. + + """ + def _external_id(n: Node, node_2_id, fallback: int) -> int: + """Return the external ID for ``n`` or ``fallback`` when absent.""" return node_2_id.get(n.name, fallback) out_node = graph_module.graph.output_node() @@ -74,6 +97,7 @@ def _external_id(n: Node, node_2_id, fallback: int) -> int: # sort nodes by the key that is id def _sort_key(t: Node) -> int: + """Key function that orders outputs by external ID or position.""" return _external_id(t, node_to_id_map, next(_counter)) orig_ord = tuple(sorted(out_list, key=_sort_key)) @@ -89,7 +113,16 @@ def _sort_key(t: Node) -> int: def arm_get_first_delegation_tag(graph_module) -> str: - """Get the first delegation tag from the graph_module or return empty string.""" + """Return the first delegation tag discovered in the FX graph. + + Args: + graph_module (GraphModule): Module produced by Arm partitioning. + + Returns: + str: First non-empty delegation tag or an empty string when no tag is + recorded. + + """ for node in graph_module.graph.nodes: tag = node.meta.get("delegation_tag") if tag: @@ -101,14 +134,26 @@ def arm_get_first_delegation_tag(graph_module) -> str: @final class TOSABackend(BackendDetails): - """ - BackendDetails subclass for lowering to TOSA. - Is used either by itself to get to a TOSA representation, or with composition - to be used as a separate step to target TOSA compliant hardware. + """Provide a backend for lowering programs to TOSA. + + Use this class standalone to produce a TOSA representation, or as part of a + composed pipeline for hardware backends that consume TOSA. + """ @staticmethod def preprocess(edge_program: ExportedProgram, compile_specs: List[CompileSpec]): + """Convert an exported program using the provided compile specs. + + Args: + edge_program (ExportedProgram): Program generated by Torch export. + compile_specs (List[CompileSpec]): Raw compile specifications from + ``executorch.apply_backend``. + + Returns: + PreprocessResult: Result containing serialized TOSA bytes. + + """ return TOSABackend._preprocess( edge_program, TosaCompileSpec.from_list(compile_specs) ) @@ -118,6 +163,31 @@ def _preprocess( # noqa: C901 edge_program: ExportedProgram, compile_spec: TosaCompileSpec, ) -> PreprocessResult: + """Lower an exported program to a TOSA flatbuffer. + + Apply Arm transformation passes to ``edge_program``, then walk the + transformed FX graph to emit a TOSA graph via the serializer. When + requested in ``compile_spec``, write additional debug artifacts. + + Args: + edge_program (ExportedProgram): Program to lower to TOSA. + compile_spec (TosaCompileSpec): Backend options. Recognized keys: + - output_format: Must be "tosa". + - tosa_spec: Target TOSA version/capabilities. + - debug_artifact_path: Directory for debug outputs. + - compile_flags: Optional backend flags. + - dump_debug_info: Enable extra debug JSON dump. + + Returns: + PreprocessResult: Result containing processed_bytes with the + serialized TOSA flatbuffer. + + Raises: + ValueError: If output_format is not "tosa" or the TOSA + specification is missing from compile_spec. + RuntimeError: If an unsupported FX node type is encountered. + + """ # if a debug/test build capture output files from TOSA stage artifact_path = compile_spec.get_intermediate_path() tosa_spec = compile_spec.tosa_spec @@ -191,11 +261,29 @@ def _preprocess_module( # noqa: C901 tosa_graph: ts.TosaSerializer, debug_hook: DebugHook | None, submodule_name: str | None = None, + containing_graph_module: GraphModule | None = None, ): - """Convert 'graph_module' to a tosa_graph""" + """Convert an FX ``graph_module`` to TOSA serializer calls. + + Args: + graph_module (GraphModule): Module to lower recursively. + edge_program (ExportedProgram): Original exported program. + compile_spec (TosaCompileSpec): Backend options with TOSA settings. + tosa_graph (ts.TosaSerializer): Serializer receiving operators. + debug_hook (DebugHook | None): Optional debug instrumentation. + submodule_name (str | None): Name used when visiting nested blocks. + + Raises: + RuntimeError: If an FX node with an unsupported op kind is found. + + """ tosa_spec = compile_spec.tosa_spec + output_node = graph_module.graph.output_node() + if isinstance(output_node.args[0], Node): + output_node.update_arg(0, [output_node.args[0]]) node_to_id_map = _annotate_external_ids(graph_module.graph) artifact_path = compile_spec.get_intermediate_path() + output_order_workaround = compile_spec.get_output_order_workaround() # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager @@ -208,7 +296,12 @@ def _preprocess_module( # noqa: C901 from executorch.backends.arm.operators.node_visitor import get_node_visitors node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook) - graph_module = _sort_outputs(graph_module, node_to_id_map) + + if output_order_workaround: + logger.debug("Re-sorting outputs during TOSA lowering.") + graph_module = _sort_outputs(graph_module, node_to_id_map) + else: + logger.debug("No re-sorting outputs (workaround) during TOSA lowering.") if submodule_name is not None: tosa_graph.startRegion(submodule_name) @@ -223,11 +316,32 @@ def _preprocess_module( # noqa: C901 if node.op == "call_function": process_call_function(node, tosa_graph, node_visitors, tosa_spec) elif node.op == "placeholder": - if len(node.users) == 0: + if len(node.users) == 0 and submodule_name is None: + # In top level module, we don't need to handle unused placeholders. + # In submodules, we do need to handle them to preserve call signature. continue - process_placeholder(node, tosa_graph, edge_program, tosa_spec) + process_placeholder( + node, + tosa_graph, + edge_program, + containing_graph_module, + tosa_spec, + ) elif node.op == "output": process_output(node, tosa_graph, tosa_spec) + elif node.op == "get_attr": + attr = getattr(graph_module, str(node.target), None) + if attr is None: + raise RuntimeError( + "get_attr node is not targeting anything in graph module." + ) + if not isinstance(attr, GraphModule): + raise RuntimeError( + "get_attr node is not targeting a GraphModule." + ) + + # If the above conditions are ok, we don't need to handle this node here. + # Only the string value of node.target is important. else: # This will only happen if an unpartitioned graph is passed without # any checking of compatibility. @@ -237,7 +351,7 @@ def _preprocess_module( # noqa: C901 raise # Recursively preprocess controlflow submodules. - for name, submodule, _ in get_control_flow_submodules(graph_module): + for name, submodule, _ in get_cond_while_submodules(graph_module): TOSABackend._preprocess_module( submodule, edge_program, @@ -245,23 +359,27 @@ def _preprocess_module( # noqa: C901 tosa_graph, debug_hook, submodule_name=name, + containing_graph_module=graph_module, ) @staticmethod def filter_tosa_compile_specs( compile_spec: ArmCompileSpec, ) -> TosaCompileSpec: - """ - Filter out the CompileSpec elements relevant for the TOSA backend. - This is needed to compose a backend targetting hardware IP with the - TOSABackend, since we first want to use the TOSABackend to generate - the TOSA flatbuffer representation as an intermediate step. The TOSA - flatbuffer can then be consumed by the backend targetting specific - hardware. - """ + """Extract the TOSA-specific settings from a composite compile spec. + + Args: + compile_spec (ArmCompileSpec): Compile specification that may + include both TOSA and hardware-specific options. + Returns: + TosaCompileSpec: TOSA-only specification ready for + ``TOSABackend.preprocess``. + + """ return ( TosaCompileSpec(compile_spec.tosa_spec) .dump_intermediate_artifacts_to(compile_spec.get_intermediate_path()) .dump_debug_info(compile_spec.tosa_debug_mode) + .set_output_order_workaround(compile_spec.output_order_workaround) ) diff --git a/backends/arm/tosa/compile_spec.py b/backends/arm/tosa/compile_spec.py index 39403c867d7..98671031e3d 100644 --- a/backends/arm/tosa/compile_spec.py +++ b/backends/arm/tosa/compile_spec.py @@ -8,12 +8,22 @@ class TosaCompileSpec(ArmCompileSpec): + """Arm-specific compile spec capturing TOSA serializer requirements.""" + def __init__(self, tosa_spec: TosaSpecification | str): + """Normalize and store the provided TOSA specification. + + Args: + tosa_spec (TosaSpecification | str): Target spec object or version + string supported by :meth:`TosaSpecification.create_from_string`. + + """ if isinstance(tosa_spec, str): tosa_spec = TosaSpecification.create_from_string(tosa_spec) self._set_compile_specs(tosa_spec, []) def validate(self): + """Ensure that no unsupported compiler flags were supplied.""" if len(self.compiler_flags) != 0: raise ValueError( f"TosaCompileSpec can't have compiler flags, got {self.compiler_flags}" @@ -22,4 +32,5 @@ def validate(self): @classmethod def get_output_format(cls) -> str: + """Return the artifact format emitted by this compile spec.""" return "tosa" diff --git a/backends/arm/tosa/dialect/lib.py b/backends/arm/tosa/dialect/lib.py index 4a807d682dc..ed26a21a297 100644 --- a/backends/arm/tosa/dialect/lib.py +++ b/backends/arm/tosa/dialect/lib.py @@ -15,6 +15,17 @@ def register_tosa_dialect_op(op_schema, func) -> Callable: + """Register a TOSA dialect operator with the backend op library. + + Args: + op_schema (str): Operator schema without namespace or overload name. + func (Callable): Fake implementation used for registration. + + Returns: + Callable: Backend dialect operator handle exposed via ``exir_ops`` and + marked ``not_callable`` for runtime use. + + """ if tosa_lib.ns not in _BACKEND_OP_LIB: _BACKEND_OP_LIB.append(tosa_lib.ns) @@ -43,6 +54,7 @@ def register_tosa_dialect_op(op_schema, func) -> Callable: # the op doesn't need to be callable. This can be changed in the future if needed to support # execution of TOSA ops directly. def not_callable(): + """Raise when the dialect op handle is invoked at runtime.""" raise RuntimeError("TOSA dialect op is not callable") op.__equvalent_callable__ = not_callable @@ -51,11 +63,22 @@ def not_callable(): class TosaValueError(ValueError): + """Error type that annotates failures with the originating TOSA op.""" + def __init__(self, message="A TOSA value error occurred", *args, op=None): + """Initialise the error with optional operator metadata. + + Args: + message (str): Human-readable error message. + *args: Additional arguments forwarded to ``ValueError``. + op: Optional operator identifier included in the string output. + + """ super().__init__(message, *args) self.op = op def __str__(self): + """Return the base message, appending the operator when provided.""" base_message = super().__str__() if self.op is not None: return f"{base_message} (TOSA op: {self.op})" diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index 1f976d0f5e0..b40b1f74a75 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -43,6 +43,12 @@ def RESIZE( ) bilinear = resize_mode == "bilinear" output_dtype = torch.int32 if bilinear else torch.int8 + elif x.dtype == torch.int16: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"Context TOSA spec {tosa_spec} doesn't support int16", op="RESIZE" + ) + output_dtype = x.dtype elif x.dtype in (torch.float16, torch.float32): if not tosa_spec.support_float(): raise TosaValueError( diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 5162d2c6a53..1b32525eed2 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -2,14 +2,14 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Provide PyTorch-to-TOSA mapping helpers. -Use these utilities to translate PyTorch dtypes and FX node metadata into -the TOSA serializer types and shapes used during initial compilation. +Use these utilities to translate PyTorch dtypes and FX node metadata into the +TOSA serializer types and shapes used during initial compilation. """ +import operator from enum import Enum from typing import Any, Optional, Sequence @@ -33,18 +33,27 @@ class TosaSpecialDtype(Enum): - """ - Special TOSA data types that are not natively supported in PyTorch, to be - used in specific scenarios as a value in the key from meta_key(). - """ + """Special TOSA dtypes not natively expressed in PyTorch.""" INT48 = ts.DType.INT48 def get_tosa_dtype(self) -> ts.DType: + """Return the underlying ``ts.DType`` enumerant. + + Returns: + ts.DType: Serializer dtype associated with the enum entry. + + """ return self.value @staticmethod def meta_key() -> str: + """Return the FX ``meta`` key that stores special dtypes. + + Returns: + str: Metadata key used to encode :class:`TosaSpecialDtype`. + + """ return "tosa_special_dtype" @@ -56,7 +65,7 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: tosa_spec (TosaSpecification): Active spec (reserved for future checks). Returns: - Any: Matching ``ts.DType`` enum value. + ts.DType: Matching serializer dtype. Raises: ValueError: If the dtype is unsupported or unknown. @@ -94,8 +103,8 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification): tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping. Returns: - tuple: ``(dtype, shape, dim_order)`` where ``dtype`` is ``ts.DType``, - ``shape`` is ``Tuple[int, ...]``, and ``dim_order`` is ``Tuple[int, ...]``. + tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing + tensor dtype, shape, and dimension order. Raises: ValueError: If ``meta['val']`` is not a ``FakeTensor``. @@ -129,14 +138,16 @@ class TosaArg: consistent structure suitable for TOSA serialization. Attributes: - name (str): Node name when argument is a ``torch.fx.Node``; empty otherwise. + name (str): Node name when argument is a ``torch.fx.Node``; empty + otherwise. dtype (ts.DType | None): Inferred dtype when available. shape (tuple[int, ...] | None): Inferred shape when available. - dim_order (tuple[int, ...] | None): Dimension order, defaulting to ``range(len(shape))``. + dim_order (tuple[int, ...] | None): Dimension order, defaulting to + ``range(len(shape))``. special (list | None): Captured list when the argument is a sequence. - number (float | int | None): Captured numeric value when given. + number (float | int | None): Captured numeric value when provided. tosa_spec (TosaSpecification): Active specification used for mapping. - + multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise. """ def __process_node(self, argument: torch.fx.Node): @@ -146,22 +157,34 @@ def __process_node(self, argument: torch.fx.Node): argument (torch.fx.Node): FX node to inspect. """ - self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "") - output_dtype, self.shape, self.dim_order = extract_tensor_meta( - argument.meta, self.tosa_spec - ) - - # Handle special case of types not representable in torch (i.e. i48_t) - if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): - output_dtype = special_type.get_tosa_dtype() + suffix = argument.meta.get(TOSA_TENSOR_NAME_META, "") + self.name = argument.name + suffix - self.dtype = output_dtype + if "val" in argument.meta: + output_dtype, self.shape, self.dim_order = extract_tensor_meta( + argument.meta, self.tosa_spec + ) + # Handle special case of types not representable in torch (i.e. i48_t) + if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): + output_dtype = special_type.get_tosa_dtype() + + self.dtype = output_dtype + + # If all users of the node are getitems, node visitors should connect the output of this node directly to the getitem tensors. + # Add a new attribute 'multiple_output_names' instead of making 'name' a list to avoid ambiguity regarding the type of 'name'. + # Make name of the output is the first getitem since we in most cases only handle that output. + users = list(argument.users) + if len(users) > 0 and all(user.target == operator.getitem for user in users): + self.multiple_output_names: list = [user.name + suffix for user in users] + self.name = self.multiple_output_names[0] + else: + self.multiple_output_names = [] def __process_list(self, argument): """Capture a sequence argument as ``special``. Args: - argument (Sequence): Sequence to store. + argument (Sequence[Any]): Sequence to store. """ self.special: list = list(argument) @@ -181,10 +204,13 @@ def __init__( """Initialize the argument wrapper and populate fields. Args: - argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, ``float``, ``torch.dtype``, or ``None``. - tosa_spec (Optional[TosaSpecification]): Active specification; required. + argument (Any): One of ``torch.fx.Node``, ``Sequence``, ``int``, + ``float``, ``torch.dtype``, or ``None``. + tosa_spec (Optional[TosaSpecification]): Active specification; + required for metadata extraction. Raises: + ValueError: If ``tosa_spec`` is missing or has the wrong type. RuntimeError: If ``argument`` is of an unsupported type. """ @@ -243,4 +269,6 @@ def __repr__(self): attrs.append(f"number={self.number!r}") if hasattr(self, "tosa_spec") and self.tosa_spec is not None: attrs.append(f"tosa_spec={self.tosa_spec!r}") + if hasattr(self, "names"): + attrs.append(f"names={self.multiple_output_names!r}") return f"{self.__class__.__name__}({', '.join(attrs)})" diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 83294369ae7..3fd88b330c2 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -2,7 +2,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - """Provide a partitioner for delegating subgraphs to the TOSA backend. Implement logic to identify and tag regions of an ``ExportedProgram`` that can @@ -11,6 +10,7 @@ - Partition graphs based on operator support and additional checks. - Prune trivial no-op partitions that would lower to empty TOSA graphs. - Tag constant data and report reasons for rejected nodes. + """ import logging @@ -22,6 +22,7 @@ from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) + from executorch.backends.arm.common.type import ensure_type from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -36,7 +37,7 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import GraphModule from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition @@ -110,8 +111,8 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool: if node.target != exir_ops.edge.aten.expand_copy.default: return False else: - multiples = calculate_multiples(node.args) - return all(m == 1 for m in multiples) + multiples, changes_rank = calculate_multiples(node.args) + return all(m == 1 for m in multiples) and not changes_rank def is_partitioned( @@ -141,6 +142,7 @@ def reject_partition( partition (object): Proposed partition object from the capability partitioner. reporter (WhyNoPartitionReporter): used to report why nodes were rejected. + """ for node in partition.nodes: if "delegation_tag" in node.meta: @@ -157,6 +159,7 @@ class TOSAPartitioner(Partitioner): Construct this partitioner for compile specs targeting TOSA. The partition algorithm uses capability checks and optional additional operator-support rules to tag nodes with a delegation tag per subgraph. + """ def __init__( @@ -190,19 +193,21 @@ def _tag_module( # noqa reporter: WhyNoPartitionReporter, tag_iterator: count | None = None, ) -> set[str]: - """Tag nodes in a module, possibly a submodule, from the containing program. + """Tag nodes in a module or submodule from the containing program. Args: module: A GraphModule from `containing_program` to tag nodes in. containing_program: The ExportedProgram that contains the module. reporter: A reporter to report why nodes were rejected. + Returns: A set of strings with the partition tags. + """ tags: set[str] = set() if tag_iterator is None: tag_iterator = count(0) - for _, submodule, _ in get_control_flow_submodules(module): + for _, submodule, _ in get_cond_while_submodules(module): submodule_tags = self._tag_module( submodule, containing_program, reporter, tag_iterator ) @@ -316,7 +321,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: tagged_exported_program=exported_program, partition_tags=partition_tags ) - def ops_to_not_decompose( + def ops_to_not_decompose( # noqa: C901 self, ep: ExportedProgram, ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: @@ -335,19 +340,31 @@ def ops_to_not_decompose( function that returns True when an op should not be decomposed. """ - ops_to_not_decompose_if_quant_op = [ + ops_to_not_decompose_if_quant_op = { + torch.ops.aten.eye.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.linear.default, - ] + torch.ops.aten.linspace.default, + } + ops_to_not_decompose_if_fp = { + torch.ops.aten.eye.default, + torch.ops.aten.logit.default, + torch.ops.aten.linear.default, + torch.ops.aten.linspace.default, + } + ops_to_not_decompose_always = { + torch.ops.aten.logit.default, + } + ops_to_not_decompose_if_integer = { + torch.ops.aten.eye.default, + torch.ops.aten.linspace.default, + } def filter_fn(node: torch.fx.Node) -> bool: - """Return True to keep selected ops intact inside quantized regions. - - The predicate holds when the target is in - ``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are - quantize/dequantize ops, indicating a quantized activation that - should not be decomposed. + """Filter function applied to ops in 'ops_to_not_decompose'. + Returns True if the op should not be decomposed. + If this function returns True, the partitioner *must* accept the node, or the lowering fails. Args: node (torch.fx.Node): FX node to evaluate. @@ -356,6 +373,12 @@ def filter_fn(node: torch.fx.Node) -> bool: bool: True to keep the op intact; otherwise, False. """ + if ( + self.tosa_spec.support_float() + and node.target in ops_to_not_decompose_if_fp + ): + return True + dq = ( torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, @@ -389,16 +412,43 @@ def filter_fn(node: torch.fx.Node) -> bool: ): correct_output_quant = True - return correct_input_quant and correct_output_quant + if correct_input_quant and correct_output_quant: + return True - # By default, do not decompose the operator - return True + if node.target in ops_to_not_decompose_if_integer: + # We only want to tag nodes as do_not_decompose if we are sure that + # we can partition them. We partition them if one or more of the + # following is true: + # 1. The node outputs an integer type. + # 2. All users cast the output to an integer type. - ops_to_not_decompose = [ - torch.ops.aten.eye.default, - torch.ops.aten.linspace.default, - torch.ops.aten.logit.default, - ] + ops_to_not_decompose_if_quant_op + dtype = get_first_fake_tensor(node).dtype + if not dtype.is_floating_point and not dtype.is_complex: + return True + + output_nodes = node.users + for user in output_nodes: + if user.target != torch.ops.aten.to.dtype: + return False + else: + cast_dtype = get_first_fake_tensor(user).dtype + if cast_dtype.is_complex or cast_dtype.is_floating_point: + return False + return True + + if node.target in ops_to_not_decompose_if_fp: + if self.tosa_spec.support_float(): + return True + if node.target in ops_to_not_decompose_always: + return True + return False + + ops_to_not_decompose = list( + ops_to_not_decompose_always + | ops_to_not_decompose_if_quant_op + | ops_to_not_decompose_if_fp + | ops_to_not_decompose_if_integer + ) if not self.tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py deleted file mode 100644 index b3840c6ab1c..00000000000 --- a/backends/arm/tosa/quant_utils.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -# Utility functions for TOSA quantized lowerings - -import math - -from typing import Any, Tuple - -import tosa_serializer as ts - - -# TOSA uses the RESCALE operation to scale between values with differing precision. -# The RESCALE operator is defined using an integer multiply, add, and shift. -# This utility function is for calculating the multiplier and shift given a scale. -# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling -def _compute_multiplier_and_shift( - scales: list[float], scaleWidth: int = 32 -) -> Tuple[list[int], list[int]]: - if scaleWidth == 16: - offset = 15 - elif scaleWidth == 32: - offset = 31 - else: - raise ValueError( - f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values." - ) - - multipliers = [] - shifts = [] - for scale in scales: - mantissa, exponent = math.frexp(scale) - shift = exponent - - const_2_power_15_or_31 = 1 << offset - shifted_mantissa = round(mantissa * const_2_power_15_or_31) - - assert ( - shifted_mantissa <= const_2_power_15_or_31 - ), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}" - - if shifted_mantissa == const_2_power_15_or_31: - shifted_mantissa = shifted_mantissa // 2 - shift += 1 - - # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. - shift = offset - shift - - # INT32_MAX, 2^31 - 1 - assert shifted_mantissa <= (const_2_power_15_or_31 - 1), ( - f"Mantissa {shifted_mantissa} exceeds signed max " - f"{const_2_power_15_or_31 - 1}" - ) - - multiplier = shifted_mantissa - - if shift > 62: - multiplier = multiplier >> min(31, shift - 62) - shift = 62 - - assert multiplier >= 0, "Multiplier should be non-negative" - assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" - multipliers.append(multiplier) - shifts.append(shift) - return multipliers, shifts - - -# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be -# const inputs. Create constant operators from the data already initialized. -def _create_const_ops_for_rescale( - tosa_fb, - scale_32, - input_dtype, - node_name, - multipliers, - shifts, - input_zp, - output_zp, - output_dtype, - ts, -): - - multipliers = tosa_fb.addConst( - (len(multipliers),), - ts.DType.INT32 if scale_32 else ts.DType.INT16, - multipliers, - name=node_name + "_multipliers", - ) - shifts = tosa_fb.addConst( - (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" - ) - input_zp = tosa_fb.addConst( - [1], input_dtype, input_zp, name=node_name + "_input_zp" - ) - output_zp = tosa_fb.addConst( - [1], output_dtype, output_zp, name=node_name + "_output_zp" - ) - - return [multipliers.name, shifts.name, input_zp.name, output_zp.name] - - -def build_rescale( - tosa_fb: Any, - scale: list[float], - input_node: Any, - output_name: str, - output_type: Any, - input_zp: list[int], - output_zp: list[int], - rounding_mode: ts.RoundingMode, - per_channel: bool = False, - is_scale32: bool = True, -): - scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 - is_scale32 = False if input_node.dtype == ts.DType.INT48 else True - multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) - rescale_inputs = _create_const_ops_for_rescale( - tosa_fb, - is_scale32, - input_node.dtype, - output_name, - multipliers, - shifts, - input_zp, - output_zp, - output_type, - ts, - ) - attr_rescale = ts.TosaSerializerAttribute() - attr_rescale.RescaleAttribute( - scale32=is_scale32, - rounding_mode=rounding_mode, - per_channel=per_channel, - input_unsigned=False, - output_unsigned=False, - ) - - tosa_fb.addOperator( - ts.Op.RESCALE, - [input_node.name, *rescale_inputs], - [output_name], - attr_rescale, - ) - - return diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py index 7afa7d9f0de..c6c79f9ad9a 100644 --- a/backends/arm/tosa/specification.py +++ b/backends/arm/tosa/specification.py @@ -12,10 +12,71 @@ import contextvars import re -from typing import List +from typing import Dict, Generic, List, Set, TypeVar from packaging.version import Version +T = TypeVar("T") + + +class TosaSpecMapping(Generic[T]): + def __init__(self): + self._mapping: Dict[TosaSpecification, List[T]] = {} + + def add(self, spec: "TosaSpecification", value: T) -> None: + """ + Adds a value to the mapping for the given TOSA specification. + The specification is normalized to its canonical form, which means that + only the version and profiles are considered, without extensions. + This allows for grouping of values under the same TOSA specification + regardless of the extensions they may have. + """ + + if spec.is_U55_subset or spec.extensions: + raise ValueError( + f"TosaSpecMapping does not support extensions, got: {spec}" + ) + + if isinstance(spec, Tosa_1_00) and len(spec.profiles) > 1: + raise ValueError( + f"TosaSpecMapping does not support multiple profiles, got: {spec}" + ) + + norm_spec = spec._canonical_key() + if norm_spec not in self._mapping: + self._mapping[norm_spec] = [] + self._mapping[norm_spec].append(value) + + @staticmethod + def _get_base_specs(spec: "TosaSpecification") -> List["TosaSpecification"]: + # Handles combined TOSA-1.0+FP+INT, etc. + if isinstance(spec, Tosa_1_00): + profiles: Set[str] = set(spec.profiles) + if profiles == {"FP", "INT"}: + version = spec.version + return [ + TosaSpecification.create_from_string(f"TOSA-{version}+FP"), + TosaSpecification.create_from_string(f"TOSA-{version}+INT"), + ] + return [spec] + + def get(self, spec: "TosaSpecification") -> List[T]: + """ + Returns a list of values associated with the given TOSA specification. + The specification is normalized to its canonical form, which means that + only the version and profiles are considered, without extensions. + """ + + base_specs = self._get_base_specs(spec) + result: List[T] = [] + for base in base_specs: + norm_base = base._canonical_key() + result.extend(self._mapping.get(norm_base, [])) + if len(result) == 0: + raise KeyError(f"No values found for TOSA specification: {spec}") + + return result # Do not deduplicate with set(), as values may be unhashable + class TosaSpecification: """Represent a TOSA specification. @@ -34,6 +95,7 @@ class TosaSpecification: version: Version is_U55_subset: bool + extensions: List[str] def support_integer(self) -> bool: """Return True if integer operations are supported.""" @@ -43,6 +105,18 @@ def support_float(self) -> bool: """Return True if floating-point operations are supported.""" raise NotImplementedError + def support_extension(self, extension: str) -> bool: + """Return True if an extension is supported and enabled. + + Args: + extension (str): Extension name (for example ``int4``, ``bf16``). + + Returns: + bool: True if the extension is valid for the active profiles and selected. + + """ + raise NotImplementedError + def __init__(self, version: Version, extras: List[str]): """Initialize the base specification. @@ -52,6 +126,7 @@ def __init__(self, version: Version, extras: List[str]): """ self.version = version + self.extensions = [] self.is_U55_subset = "u55" in extras if self.is_U55_subset: @@ -89,6 +164,12 @@ def create_from_string(repr: str) -> "TosaSpecification": raise ValueError(f"Failed to parse TOSA specification representation: {repr}") + def _canonical_key(self) -> "TosaSpecification": + """ + Returns a new TosaSpecification instance with only version and profiles (no extensions). + """ + raise NotImplementedError + class Tosa_1_00(TosaSpecification): """Provide TOSA 1.00 profile and extension semantics. @@ -232,6 +313,16 @@ def support_extension(self, extension: str) -> bool: return False + def _canonical_key(self) -> "Tosa_1_00": + """ + Returns a new Tosa_1_00 instance with only major.minor version and profiles (no extensions). + Patch version is set to zero for normalization. + """ + from packaging.version import Version + + norm_version = Version(f"{self.version.major}.{self.version.minor}.0") + return Tosa_1_00(norm_version, self.profiles.copy()) + class TosaLoweringContext: """Manage the TOSA specification context for lowering. diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 14a22298d8a..60ed0376697 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -2,7 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +"""Utility helpers for building TOSA graphs in the Arm backend.""" import logging from typing import Any @@ -26,19 +26,21 @@ def are_fake_tensors_broadcastable( fake_tensors: list[FakeTensor], ) -> tuple[bool, list[int]]: - """ - Determines whether a list of FakeTensors can be broadcast together. + """Determine whether the fake tensors share a broadcastable shape. + Args: - fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors - who's shapes to evaluate + fake_tensors (list[FakeTensor]): Fake tensors whose shapes should + be validated for broadcasting. Returns: - tuple[bool, list[int]]: First element is whether the shapes are - broadcastable. Second element is the common shape if compatible. - If not, empty list. + tuple[bool, list[int]]: Tuple where the first element indicates + whether broadcasting is possible and the second element contains + the broadcast shape. The shape list is empty when broadcasting + fails. Raises: - RuntimeError: If less than 2 tensors are passed in. + RuntimeError: Raised when fewer than two tensors are supplied. + """ if len(fake_tensors) < 1: raise RuntimeError(f"Expected 2 or more tensors got {len(fake_tensors)}") @@ -65,26 +67,27 @@ def are_fake_tensors_broadcastable( def broadcast_tensors( tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification ) -> list[Any]: - """ - Given a list of nodes it determines the common shape they broadcast to - and adds the necessary reshape and tile operations to perform the broadcast. + """Broadcast the FX nodes to a shared shape inside the TOSA graph. + + This mirrors ``reshape_for_broadcast`` but also emits the tile operators + needed to materialize the broadcast and supports any number of inputs. Args: - tosa_fb: Tosa graph to add nodes to - nodes (list[Node]): List of nodes to broadcast together - tosa_spec (TosaSpecification): Tosa spec + tosa_fb (Any): TOSA graph builder that receives the broadcast + operators. + nodes (list[Node]): FX nodes whose tensor metadata should be + broadcast. + tosa_spec (TosaSpecification): Active TOSA specification used to + decode tensor metadata. Returns: - list[Any]: List containing the fx.Nodes or TosaSerializerTensors - of the right common shape. Order of output matches order of input. + list[Any]: Broadcast versions of the inputs. Each element is either + the original FX node or a TOSA serializer tensor, ordered to match + ``nodes``. Raises: RuntimeError: If the supplied nodes are not broadcastable. - Note: - This function and `reshape_for_broadcast` both reshape the tensors - for broadcast. However this function also performs the broadcast and - does not have a limit on only two input tensors. """ index_fake_tensors = [node.meta["val"] for node in nodes] broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors) @@ -137,6 +140,17 @@ def broadcast_tensors( def build_reshape_tosa_1_0( tosa_graph, input_name, new_shape, output_name, shape_name_override="" ): + """Insert a TOSA reshape operator using the v1.0 semantics. + + Args: + tosa_graph (Any): Graph builder used to emit TOSA operators. + input_name (str): Name of the tensor that should be reshaped. + new_shape (list[int]): Target tensor shape. + output_name (str): Name assigned to the reshaped tensor. + shape_name_override (str): Optional override for the shape constant + name. + + """ shape = tosa_graph.addConst( np.array(new_shape).shape, ts.DType.SHAPE, @@ -155,6 +169,19 @@ def build_reshape_tosa_1_0( def tosa_shape(shape, dim_order): + """Reorder a shape tuple into TOSA layout while resolving symints. + + Args: + shape (Sequence[int | torch.SymInt]): Original tensor shape, + possibly containing ``torch.SymInt``. + dim_order (Sequence[int]): Desired dimension order for the output + shape. + + Returns: + list[int]: List containing the reordered dimensions where symbolic + values become ``-1``. + + """ reordered = tuple([shape[dim] for dim in dim_order]) # Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes, # in TOSA we do not have this concept and instead use -1. @@ -170,6 +197,26 @@ def get_resize_parameters_1d( resize_mode: int, align_corners: bool, ): + """Compute resize coefficients for a single spatial dimension. + + Args: + input_size (int | torch.SymInt): Input size for the axis, possibly + symbolic. + output_size (int | torch.SymInt): Output size for the axis, possibly + symbolic. + resize_mode (int): Target resize mode defined by TOSA. + align_corners (bool): Whether the resize should align the corner + pixels. + + Returns: + tuple[int, int, int, int]: Numerator, denominator, offset, and border + terms encoded as integers. + + Raises: + RuntimeError: If symbolic shapes are used with ``align_corners`` or if + the computed ratio or border is not constant. + + """ # We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky. if align_corners: if (not isinstance(input_size, int)) or (not isinstance(output_size, int)): @@ -229,19 +276,23 @@ def get_resize_parameters( resize_mode: int, align_corners: bool, ) -> tuple[torch.IntTensor, ...]: - """Get the tosa.resize parameters based on the input and output size. + """Calculate 2D resize parameters for TOSA emission. Args: - input_size_xy (tuple[int | torch.SymInt]): Size of the input - output_size_xy (tuple[int | torch.SymInt]): Size of the output - resize_mode (tosa.ResizeMode): The TOSA resize mode - align_corners (bool): Align the corners pixels of the input and output + input_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height + and width of the input tensor. + output_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height + and width of the output tensor. + resize_mode (int): TOSA resize mode used for coefficient generation. + align_corners (bool): Whether to align corner pixels between input and + output. Returns: - scale_n (torch.IntTensor), scale_d (torch.IntTensor), - offset (torch.IntTensor), border (torch.IntTensor) - """ + tuple[torch.IntTensor, ...]: Four-element tuple of tensors describing + the scale numerator, scale denominator, offset, and border for Y + and X dimensions. + """ # Get the parameters for each dimension independently y_params = get_resize_parameters_1d( input_size_xy[0], output_size_xy[0], resize_mode, align_corners diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py index 5d33812df97..50a3e2eac88 100644 --- a/backends/arm/util/arm_model_evaluator.py +++ b/backends/arm/util/arm_model_evaluator.py @@ -167,14 +167,14 @@ def __init__( self, model_name: str, fp32_model: torch.nn.Module, - int8_model: torch.nn.Module, + quant_model: torch.nn.Module, example_input: Tuple[torch.Tensor], tosa_output_path: Optional[str], ) -> None: self.model_name = model_name self.fp32_model = fp32_model - self.int8_model = int8_model + self.quant_model = quant_model self.example_input = example_input if tosa_output_path: @@ -192,12 +192,12 @@ def get_model_error(self) -> defaultdict: mean_absolute_error """ fp32_outputs, _ = tree_flatten(self.fp32_model(*self.example_input)) - int8_outputs, _ = tree_flatten(self.int8_model(*self.example_input)) + quant_outputs, _ = tree_flatten(self.quant_model(*self.example_input)) model_error_dict = defaultdict(list) - for fp32_output, int8_output in zip(fp32_outputs, int8_outputs): - difference = fp32_output - int8_output + for fp32_output, quant_output in zip(fp32_outputs, quant_outputs): + difference = fp32_output - quant_output # Avoid divide by zero: elements where fp32 == 0 produce 0% contribution percentage_error = torch.where( fp32_output != 0, @@ -238,7 +238,6 @@ def evaluate(self) -> dict[str, Any]: if self.tosa_output_path: # We know output_metrics["metrics"] is list since we just defined it, safe to ignore. - # pyre-ignore[16] output_metrics["metrics"][ # type: ignore[index] "compression_ratio" ] = self.get_compression_ratio() @@ -253,14 +252,14 @@ def __init__( self, model_name: str, fp32_model: Module, - int8_model: Module, + quant_model: Module, example_input: Tuple[torch.Tensor], tosa_output_path: str | None, batch_size: int, validation_dataset_path: str, ) -> None: super().__init__( - model_name, fp32_model, int8_model, example_input, tosa_output_path + model_name, fp32_model, quant_model, example_input, tosa_output_path ) self.__batch_size = batch_size @@ -280,7 +279,7 @@ def from_config( cls, model_name: str, fp32_model: Module, - int8_model: Module, + quant_model: Module, example_input: Tuple[torch.Tensor], tosa_output_path: str | None, config: dict[str, Any], @@ -292,7 +291,7 @@ def from_config( return cls( model_name, fp32_model, - int8_model, + quant_model, example_input, tosa_output_path, batch_size=config["batch_size"], @@ -303,10 +302,9 @@ def evaluate(self) -> dict[str, Any]: # Load dataset and compute top-1 / top-5 dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path) top1_correct, top5_correct = GenericModelEvaluator.evaluate_topk( - self.int8_model, dataset, self.__batch_size, topk=5 + self.quant_model, dataset, self.__batch_size, topk=5 ) output = super().evaluate() - output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct} return output @@ -318,14 +316,14 @@ def __init__( self, model_name: str, fp32_model: Module, - int8_model: Module, + quant_model: Module, example_input: Tuple[torch.Tensor], tosa_output_path: str | None, batch_size: int, validation_dataset_path: str, ) -> None: super().__init__( - model_name, fp32_model, int8_model, example_input, tosa_output_path + model_name, fp32_model, quant_model, example_input, tosa_output_path ) self.__batch_size = batch_size self.__validation_set_path = validation_dataset_path @@ -344,7 +342,7 @@ def from_config( cls, model_name: str, fp32_model: Module, - int8_model: Module, + quant_model: Module, example_input: Tuple[torch.Tensor], tosa_output_path: str | None, config: dict[str, Any], @@ -356,7 +354,7 @@ def from_config( return cls( model_name, fp32_model, - int8_model, + quant_model, example_input, tosa_output_path, batch_size=config["batch_size"], @@ -367,7 +365,7 @@ def evaluate(self) -> dict[str, Any]: # Load dataset and compute top-1 / top-5 dataset = DeiTTinyEvaluator.__load_dataset(self.__validation_set_path) top1, top5 = GenericModelEvaluator.evaluate_topk( - self.int8_model, dataset, self.__batch_size, topk=5 + self.quant_model, dataset, self.__batch_size, topk=5 ) output = super().evaluate() output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5} @@ -381,14 +379,14 @@ def __init__( self, model_name: str, fp32_model: Module, - int8_model: Module, + quant_model: Module, example_input: Tuple[torch.Tensor], tosa_output_path: str | None, batch_size: int, validation_dataset_path: str, ) -> None: super().__init__( - model_name, fp32_model, int8_model, example_input, tosa_output_path + model_name, fp32_model, quant_model, example_input, tosa_output_path ) self.__batch_size = batch_size self.__validation_set_path = validation_dataset_path @@ -407,7 +405,7 @@ def from_config( cls, model_name: str, fp32_model: Module, - int8_model: Module, + quant_model: Module, example_input: Tuple[torch.Tensor], tosa_output_path: str | None, config: dict[str, Any], @@ -415,7 +413,7 @@ def from_config( return cls( model_name, fp32_model, - int8_model, + quant_model, example_input, tosa_output_path, batch_size=config["batch_size"], @@ -425,7 +423,7 @@ def from_config( def evaluate(self) -> dict[str, Any]: dataset = ResNet18Evaluator.__load_dataset(self.__validation_set_path) top1, top5 = GenericModelEvaluator.evaluate_topk( - self.int8_model, dataset, self.__batch_size, topk=5 + self.quant_model, dataset, self.__batch_size, topk=5 ) output = super().evaluate() output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5} @@ -464,8 +462,9 @@ def evaluator_calibration_data( def evaluate_model( model_name: str, intermediates: str, + target: str, model_fp32: torch.nn.Module, - model_int8: torch.nn.Module, + model_quant: torch.nn.Module, example_inputs: Tuple[torch.Tensor], evaluator_name: str, evaluator_config: str | None, @@ -487,7 +486,7 @@ def evaluate_model( init_evaluator = factory( model_name, model_fp32, - model_int8, + model_quant, example_inputs, str(tosa_paths[0]), config, @@ -498,11 +497,11 @@ def evaluate_model( ) else: init_evaluator = evaluator( - model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0]) + model_name, model_fp32, model_quant, example_inputs, str(tosa_paths[0]) ) quant_metrics = init_evaluator.evaluate() - output_json_path = intermediates_path / "quant_metrics.json" + output_json_path = intermediates_path / f"{target}-quant_metrics.json" with output_json_path.open("w") as json_file: json.dump(quant_metrics, json_file) diff --git a/backends/arm/vgf/backend.py b/backends/arm/vgf/backend.py index 82d200f44fd..d22dc27afa0 100644 --- a/backends/arm/vgf/backend.py +++ b/backends/arm/vgf/backend.py @@ -10,6 +10,7 @@ # this form is used where the final JIT compile is performed on target (in the # runtime delegate executorch::runtime::BackendInterface::init # +"""Ahead-of-time Arm VGF backend built on the shared TOSA pipeline.""" import logging import os @@ -17,13 +18,24 @@ import tempfile from typing import final, List -from executorch.backends.arm.tosa.backend import ( +from executorch.backends.arm.tosa.backend import ( # type: ignore[import-not-found] arm_get_first_delegation_tag, TOSABackend, ) -from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec -from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult -from executorch.exir.backend.compile_spec_schema import CompileSpec + +from executorch.backends.arm.vgf.compile_spec import ( # type: ignore[import-not-found] + VgfCompileSpec, +) +from executorch.backends.arm.vgf.model_converter import ( # type: ignore[import-not-found] + require_model_converter_binary, +) +from executorch.exir.backend.backend_details import ( # type: ignore[import-not-found] + BackendDetails, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) from torch.export.exported_program import ExportedProgram # debug functionality @@ -32,9 +44,11 @@ @final class VgfBackend(BackendDetails): - """ - BackendDetails subclass for delegation to VGF compatible devices. This enables - encapsulated TOSA on target device and JIT compilation on suitable platforms. + """BackendDetails subclass for delegation to VGF compatible devices. + + This enables encapsulated TOSA on target device and JIT compilation on + suitable platforms. + """ @staticmethod @@ -43,9 +57,18 @@ def _compile_tosa_flatbuffer( compile_spec: VgfCompileSpec, tag_name: str = "", ) -> bytes: - """ - Static helper method to do the compilation of the TOSA flatbuffer - representation to a target specific binary stream. + """Compile a TOSA flatbuffer into a target-specific binary stream. + + Args: + tosa_flatbuffer (bytes): Serialized TOSA graph produced by + ``TOSABackend``. + compile_spec (VgfCompileSpec): Compile specification providing + converter flags and artifact paths. + tag_name (str): Optional suffix used when producing debug outputs. + + Returns: + bytes: Target-specific VGF binary stream. + """ compile_flags = compile_spec.compiler_flags artifact_path = compile_spec.get_intermediate_path() @@ -58,6 +81,17 @@ def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], ) -> PreprocessResult: + """Lower the exported program and compile it for a VGF target. + + Args: + edge_program (ExportedProgram): Program to lower to VGF. + compile_specs (List[CompileSpec]): Serialized VGF compile specs + supplied by the frontend. + + Returns: + PreprocessResult: Result containing the compiled VGF binary. + + """ logger.info(f"{VgfBackend.__name__} preprocess") compile_spec = VgfCompileSpec.from_list(compile_specs) @@ -87,6 +121,20 @@ def vgf_compile( artifact_path: str | None = None, tag_name: str = "", ): + """Invoke the VGF compiler to convert a TOSA flatbuffer. + + Args: + tosa_flatbuffer (bytes): Serialized TOSA graph produced by + ``TOSABackend``. + compile_flags (List[str]): Command-line flags forwarded to + ``model-converter``. + artifact_path (str | None): Directory where debug artifacts are saved. + tag_name (str): Optional suffix used when producing debug outputs. + + Returns: + bytes: Compiled VGF binary emitted by ``model-converter``. + + """ with tempfile.TemporaryDirectory() as tmpdir: # We currently write out a flatbuffer as input to the converter @@ -96,9 +144,10 @@ def vgf_compile( f.write(tosa_flatbuffer) additional_flags = " ".join(compile_flags) + converter_binary = require_model_converter_binary() vgf_path = tosa_path + ".vgf" conversion_command = ( - f"model-converter {additional_flags} -i {tosa_path} -o {vgf_path}" + f"{converter_binary} {additional_flags} -i {tosa_path} -o {vgf_path}" ) try: subprocess.run( diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py index 860c8bf3687..f0ae83c654a 100644 --- a/backends/arm/vgf/compile_spec.py +++ b/backends/arm/vgf/compile_spec.py @@ -15,19 +15,23 @@ class VgfCompileSpec(ArmCompileSpec): - """ - Compile spec for VGF compatible targets. - - Args: - tosa_spec: TOSA specification that should be targeted. - compiler_flags: Extra compiler flags for converter_backend. - """ + """Compile specification for VGF-compatible targets.""" def __init__( self, tosa_spec: TosaSpecification | str | None = None, compiler_flags: list[str] | None = None, ): + """Normalise inputs and populate the underlying Arm compile spec. + + Args: + tosa_spec (TosaSpecification | str | None): TOSA specification to + target. Strings are parsed via + :meth:`TosaSpecification.create_from_string`. Defaults to + ``"TOSA-1.0+FP"``. + compiler_flags (list[str] | None): Optional converter-backend flags. + + """ if tosa_spec is None: tosa_spec = "TOSA-1.0+FP" if isinstance(tosa_spec, str): @@ -39,7 +43,7 @@ def __init__( self.validate() def validate(self): - """Throws an error if the compile spec is not valid.""" + """Validate the configuration against VGF-supported TOSA profiles.""" tosa_version = self.tosa_spec.version # type: ignore[attr-defined] tosa_profiles = self.tosa_spec.profiles # type: ignore[attr-defined] @@ -63,4 +67,5 @@ def validate(self): @classmethod def get_output_format(cls) -> str: + """Return the artifact format emitted by this compile spec.""" return "vgf" diff --git a/backends/arm/vgf/model_converter.py b/backends/arm/vgf/model_converter.py new file mode 100644 index 00000000000..dffbf76f26a --- /dev/null +++ b/backends/arm/vgf/model_converter.py @@ -0,0 +1,34 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from shutil import which +from typing import Optional + +MODEL_CONVERTER_BINARY = "model-converter" +_MODEL_CONVERTER_FALLBACK_BINARY = "model_converter" + + +def find_model_converter_binary() -> Optional[str]: + """Return the name of the first model converter executable on PATH.""" + + for candidate in (MODEL_CONVERTER_BINARY, _MODEL_CONVERTER_FALLBACK_BINARY): + if which(candidate): + return candidate + return None + + +def require_model_converter_binary() -> str: + """Return a usable model converter executable or raise a helpful error.""" + + binary = find_model_converter_binary() + if binary is None: + tried = ", ".join((MODEL_CONVERTER_BINARY, _MODEL_CONVERTER_FALLBACK_BINARY)) + raise RuntimeError( + "Unable to locate a model converter executable. " + f"Tried: {tried}. Ensure the Model Converter is installed and on PATH." + ) + return binary diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 3dd612e650e..69882218bf5 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -144,7 +144,7 @@ def convert_pt2( # It is however useful for unit tests to separate the converted model from the # fused model, to be able to get reference numerics. # If this does not apply, please use quantize_pt2 instead. -def fuse_pt2( +def apply_pre_edge_transform_passes( converted_program: ExportedProgram, quantizer: CadenceQuantizer, ) -> ExportedProgram: @@ -229,7 +229,7 @@ def quantize_pt2( # Apply quant fusion to the exported program program = torch.export.export(converted_gm, inputs, strict=True) - fused_program = fuse_pt2(program, quantizer) + fused_program = apply_pre_edge_transform_passes(program, quantizer) if dump_graphs: logging.info("Graph after quantization and fusion:") diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 20719322e82..cf4fa484997 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -18,8 +18,8 @@ from executorch.backends.cadence.aot.compiler import ( _lower_ep_to_cadence_gen_etrecord, + apply_pre_edge_transform_passes, convert_pt2, - fuse_pt2, prepare_pt2, ) @@ -66,7 +66,7 @@ def export_model( ep = torch.export.export(converted_model, example_inputs, strict=True) # Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2) - ep = fuse_pt2(ep, quantizer) + ep = apply_pre_edge_transform_passes(ep, quantizer) # Get edge program after Cadence specific passes exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord( diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 0220baa593f..c1fba3b110b 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -351,10 +351,6 @@ def register_fake( "quantized_matmul_asym8uxasym8u_asym8u.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)" ) -lib.define( - "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " - "int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)" -) lib.define( "transposed_convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)" @@ -489,8 +485,28 @@ def register_fake( # ------------------------------------ # # Migrated from the custom_ops.yaml files containing different operator variants (e.g., .out, .tensor_out) lib.define( - "convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, " - "int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" + "conv1d(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, " + "int groups) -> Tensor" +) +lib.define( + "conv1d.out(Tensor input, Tensor weight, Tensor bias, int[1] stride, SymInt[1] padding, int[1] dilation, " + "int groups, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "conv2d(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, " + "int groups) -> Tensor" +) +lib.define( + "conv2d.out(Tensor input, Tensor weight, Tensor bias, int[2] stride, SymInt[2] padding, int[2] dilation, " + "int groups, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "conv3d(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, " + "int groups) -> Tensor" +) +lib.define( + "conv3d.out(Tensor input, Tensor weight, Tensor bias, int[3] stride, SymInt[3] padding, int[3] dilation, " + "int groups, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " @@ -2152,8 +2168,8 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta( return src.new_empty(out_size, dtype=src.dtype) -@register_fake("cadence::convolution") -def convolution_meta( +@register_fake("cadence::conv1d") +def conv1d_meta( input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -2161,32 +2177,93 @@ def convolution_meta( padding: Tuple[int], dilation: Tuple[int], groups: int, - channel_last: bool = False, ) -> torch.Tensor: - if channel_last: - out_channels, *kernel_size, _ = weight.shape - else: - out_channels, _, *kernel_size = weight.shape + assert ( + len(weight.shape) == 3 + ), f"Conv1d expects a 3D weight, got {len(weight.shape)}D" + out_channels, _, kernel_size = weight.shape in_size = input.shape - # Assert that the input tensor has at least 3 dimensions, and at most 6 - assert len(in_size) > 2 - assert len(in_size) < 6 + assert len(in_size) == 3, f"conv1d expects 3D input, got {len(in_size)}D" - # Compute the output tensor size - output_size = ( - get_conv1d_output_size( - in_size, - out_channels, - stride[0], - padding[0], - dilation[0], - kernel_size[0], - channel_last, - ) - if len(in_size) == 3 - else get_conv2d_output_size( - in_size, out_channels, stride, padding, dilation, kernel_size, channel_last - ) + output_size = get_conv1d_output_size( + in_size, + out_channels, + stride[0], + padding[0], + dilation[0], + kernel_size, + False, + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::conv2d") +def conv2d_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, +) -> torch.Tensor: + assert ( + len(weight.shape) == 4 + ), f"Conv2d expects a 4D weight, got {len(weight.shape)}D" + out_channels, _, *kernel_size = weight.shape + in_size = input.shape + assert len(in_size) == 4, f"conv2d expects 4D input, got {len(in_size)}D" + + output_size = get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, False + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::conv3d") +def conv3d_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int, int, int], + padding: Tuple[int, int, int], + dilation: Tuple[int, int, int], + groups: int, +) -> torch.Tensor: + assert ( + len(weight.shape) == 5 + ), f"Conv3d expects a 5D weight, got {len(weight.shape)}D" + out_channels, _, *kernel_size = weight.shape + in_size = input.shape + assert len(in_size) == 5, f"conv3d expects 5D input, got {len(in_size)}D" + + # Helper to compute 3D convolution output size + def get_conv3d_output_size( + in_size: torch.Size, + out_channels: int, + stride: Tuple[int, int, int], + padding: Tuple[int, int, int], + dilation: Tuple[int, int, int], + kernel_size: list[int], + ) -> torch.Size: + N, C, D, H, W = in_size + + dout = (D + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[ + 0 + ] + 1 + hout = (H + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[ + 1 + ] + 1 + wout = (W + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) // stride[ + 2 + ] + 1 + + return torch.Size((N, out_channels, dout, hout, wout)) + + output_size = get_conv3d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size ) return input.new_empty(output_size, dtype=input.dtype) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 5a8cba0361d..bc589325025 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1334,8 +1334,25 @@ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... -@impl_tracked(m, "convolution") -def convolution( +@impl_tracked(m, "conv1d") +def conv1d( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int], + padding: tuple[int], + dilation: tuple[int], + groups: int, +) -> torch.Tensor: + conv_out = torch.nn.functional.conv1d( + input_tensor, weight, bias, stride[0], padding[0], dilation[0], groups + ) + + return conv_out + + +@impl_tracked(m, "conv2d") +def conv2d( input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, @@ -1343,39 +1360,27 @@ def convolution( padding: tuple[int, int], dilation: tuple[int, int], groups: int, - channel_last: bool = False, ) -> torch.Tensor: - conv_is_1d = len(input_tensor.shape) == 3 - if channel_last: - if conv_is_1d: - input_tensor = input_tensor.movedim(-1, 1).contiguous() - if len(weight.shape) != 3: - raise ValueError("Weight tensor must be 3D if input is 3D") - weight = weight.movedim(-1, 1).contiguous() - else: - input_tensor = input_tensor.movedim(-1, -3) - if len(weight.shape) != 4: - raise ValueError("Weight tensor must be 4D if input is nd > 3") - weight = torch.permute(weight, (0, -1, 1, 2)).contiguous() + conv_out = torch.nn.functional.conv2d( + input_tensor, weight, bias, stride, padding, dilation, groups + ) - _stride: tuple[int, int] | int = stride - _padding: tuple[int, int] | int = padding - _dilation: tuple[int, int] | int = dilation + return conv_out - if conv_is_1d: - conv = torch.nn.functional.conv1d - _stride = stride[0] - _padding = padding[0] - _dilation = dilation[0] - else: - conv = torch.nn.functional.conv2d - conv_out = conv(input_tensor, weight, bias, _stride, _padding, _dilation, groups) - if channel_last: - if conv_is_1d: - conv_out = conv_out.movedim(1, -1).contiguous() - else: - conv_out = conv_out.movedim(-3, -1).contiguous() +@impl_tracked(m, "conv3d") +def conv3d( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int, int], + padding: tuple[int, int, int], + dilation: tuple[int, int, int], + groups: int, +) -> torch.Tensor: + conv_out = torch.nn.functional.conv3d( + input_tensor, weight, bias, stride, padding, dilation, groups + ) return conv_out diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 289c5ffeeec..9dc695c68af 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -6,9 +6,8 @@ # pyre-strict -import logging from dataclasses import dataclass, field -from typing import cast, List, Optional, Sequence, Set, Type +from typing import cast, List, Optional, Set, Type # Import these for the cadence function signatures. import executorch.backends.cadence.aot.ops_registrations # noqa: F401 @@ -47,19 +46,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveDetachCopyPass(ExportPass): - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.detach_copy.default: - return super().call_operator(op, args, kwargs, meta) +class RemoveDetachCopyPass(RemoveOrReplacePassInterface): + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.detach_copy.default] - assert len(args) == 1 - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True # The following class consolidates passes to remove ops that are redundant: @@ -72,117 +68,123 @@ class RemoveRedundantOps: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveZeroSizedCatArgsPass(ExportPass): - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.cat.default: - return super().call_operator(op, args, kwargs, meta) - - # Remove any zero-sized tensor arg to form a new args list. - cat_inputs: list[ProxyValue] = [] - for arg in cast(Sequence[ProxyValue], args[0]): - if arg.to_tensor().numel() > 0: - cat_inputs.append(arg) +class RemoveZeroSizedCatArgsPass(RemoveOrReplacePassInterface): + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.cat.default] - # If all the tensors were empty, we just return an empty tensor with - # the right shape. + def maybe_remove_or_replace(self, node: Node) -> bool: + # Get the cat inputs (first argument is a list of tensors) + cat_inputs_arg = node.args[0] + + # Assert that cat_inputs_arg is iterable + assert isinstance( + cat_inputs_arg, (list, tuple) + ), "cat_inputs_arg must be a sequence type" + + # Filter out zero-sized tensors + cat_inputs: list[Node] = [] + for arg in cat_inputs_arg: + if isinstance(arg, Node) and arg.meta.get("val") is not None: + if arg.meta["val"].numel() > 0: + cat_inputs.append(arg) + + # If all tensors were empty, create a full op with the right shape if not cat_inputs: - empty_shape = meta["val"].shape - dtype = meta["val"].dtype - return super().call_operator( - exir_ops.edge.aten.full.default, - (tuple(empty_shape), 0), - {"dtype": dtype}, - meta, - ) + empty_shape = node.meta["val"].shape + dtype = node.meta["val"].dtype + # Create a new full node + with node.graph.inserting_before(node): + full_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=(tuple(empty_shape), 0), + kwargs={"dtype": dtype}, + ) + full_node.meta = node.meta.copy() + node.replace_all_uses_with(full_node) + return True - # If there was only one tensor in the cat_inputs list, - # we can safely erase this cat op. + # If only one tensor remains, replace with it if len(cat_inputs) == 1: - return cat_inputs[0] + node.replace_all_uses_with(cat_inputs[0]) + return True + + # If the number of inputs changed, update the cat args + if len(cat_inputs) < len(cat_inputs_arg): + # Update the first argument with filtered inputs + new_args = list(node.args) + new_args[0] = cat_inputs + node.args = tuple(new_args) + return True - # Otherwise, we replace args[0] with cat_inputs. - new_args = list(args) - # pyre error introduced after D66937105 - new_args[0] = cat_inputs # pyre-ignore[6] - return super().call_operator(op, tuple(new_args), kwargs, meta) + # No changes needed + return False @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveNopExpandOpPass(ExportPass): +class RemoveNopExpandOpPass(RemoveOrReplacePassInterface): """ For an expand op, if the operator shape matches the expand shape, then the expand is a nop. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if get_edge_overload_packet(op) not in { - exir_ops.edge.aten.expand_copy, - exir_ops.edge.aten.expand, - }: - return super().call_operator(op, args, kwargs, meta) - - # Parse the args, and check for nop condition - arg0 = cast(ProxyValue, args[0]) - arg1 = cast(Sequence[int], args[1]) - in_tensor = arg0.to_tensor() - if list(in_tensor.shape) == list(arg1): - return arg0 + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.expand_copy.default, + exir_ops.edge.aten.expand.default, + ] - return super().call_operator(op, args, kwargs, meta) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + if input_node.meta["val"].shape == node.meta["val"].shape: + node.replace_all_uses_with(input_node) + return True + return False @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveToOpsPass(ExportPass): +class RemoveToOpsPass(RemoveOrReplacePassInterface): # aten.to.* as of now are all nops - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in ( + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.to.dtype, exir_ops.edge.aten.to.dtype_layout, - ): - return super().call_operator(op, args, kwargs, meta) + ] - logging.debug(f"Erasing to.dtype node (target = {op})") - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveZeroSizedConstantPadNd(ExportPass): - def call_operator( - self, - op, # pyre-ignore - args: tuple[ProxyValue, tuple[int, ...], Argument], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.constant_pad_nd.default: - return super().call_operator(op, args, kwargs, meta) +class RemoveZeroSizedConstantPadNd(RemoveOrReplacePassInterface): + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.constant_pad_nd.default] + + def maybe_remove_or_replace(self, node: Node) -> bool: + # Get padding argument (second argument) + if len(node.args) < 2: + return False - input_tensor = args[0] - padding = args[1] + padding = node.args[1] + if not isinstance(padding, (list, tuple)): + return False + # If any padding value is non-zero, keep the node if any(x != 0 for x in padding): - return super().call_operator(op, args, kwargs, meta) + return False - logging.debug(f"Erasing 0 sized constant pad nd node with {input_tensor}") - return input_tensor + # All padding is zero, replace with input + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -210,40 +212,37 @@ def maybe_remove_or_replace(self, node: Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopLinalgVectorNormOpPass(ExportPass): +class RemoveNopLinalgVectorNormOpPass(RemoveOrReplacePassInterface): """ If the norm is applied over a dimension that is size 1, it can be eliminated. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op is not exir_ops.edge.aten.linalg_vector_norm.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.linalg_vector_norm.default] + def maybe_remove_or_replace(self, node: Node) -> bool: # If the op has three args or less, it can't be a nop - if len(args) <= 3: - return super().call_operator(op, args, kwargs, meta) + if len(node.args) <= 3: + return False # If dim is None, or keepdim is False, it is not a nop - dim = cast(Optional[tuple[int, ...]], args[2]) - keepdim = cast(bool, args[3]) + dim = cast(Optional[tuple[int, ...]], node.args[2]) + keepdim = cast(bool, node.args[3]) if dim is None or not keepdim: - return super().call_operator(op, args, kwargs, meta) + return False # If the norm has 4 args and keepdim is True, check if dim is not None # and if the dimensions in dim are size 1. If not, the norm is not a nop. - t = cast(ProxyValue, args[0]) - shape = t.to_tensor().shape - if len(args) < 4: + input_node = node.args[0] + assert isinstance(input_node, Node) + shape = input_node.meta["val"].shape + if len(node.args) < 4: for d in dim: if shape[d] != 1: - return super().call_operator(op, args, kwargs, meta) + return False - return t + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -358,23 +357,21 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveCloneOpPass(ExportPass): +class RemoveCloneOpPass(RemoveOrReplacePassInterface): # If the op is a clone op, return the input and eliminate the op - def call_operator( - self, - op, # pyre-ignore - args: tuple[ProxyValue], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.clone.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.clone.default] - return args[0] + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveContiguousOpPass(ExportPass): +class RemoveContiguousOpPass(RemoveOrReplacePassInterface): """ This is based on the assumption that all tensors are contiguous in ExecuTorch and after cadence passes, and we should revisit this if that assumption is no longer true. @@ -382,43 +379,37 @@ class RemoveContiguousOpPass(ExportPass): original graph module. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.contiguous.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.contiguous.default] - assert len(args) == 1 - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class RemoveAliasCopyOpPass(ExportPass): +class RemoveAliasCopyOpPass(RemoveOrReplacePassInterface): """ alias_copy is a no-op and can be removed. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.alias_copy.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.alias_copy.default] - assert len(args) == 1 - return cast(ProxyValue, args[0]) + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + node.replace_all_uses_with(input_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopRequantizeOpPass(ExportPass): +class RemoveNopRequantizeOpPass(RemoveOrReplacePassInterface): """ For a requantize op, if the following three conditions are satisfied: 1. the in_scale matches the out_scale @@ -427,100 +418,96 @@ class RemoveNopRequantizeOpPass(ExportPass): then the requantize op is redundant, and can be eliminated """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.cadence.requantize.per_tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.cadence.requantize.per_tensor] - # Parse the args - (X, in_scale, in_zero_point, out_scale, out_zero_point, out_dtype) = cast( - tuple[ProxyValue, int, float, int, float, torch.dtype], args - ) - in_dtype = X.to_tensor().dtype + def maybe_remove_or_replace(self, node: Node) -> bool: + input_node = node.args[0] + assert isinstance(input_node, Node) + in_scale = node.args[1] + in_zero_point = node.args[2] + out_scale = node.args[3] + out_zero_point = node.args[4] + out_dtype = node.args[5] + in_dtype = input_node.meta["val"].dtype # Check the three conditions if ( in_scale == out_scale and in_zero_point == out_zero_point and in_dtype == out_dtype ): - return cast(ProxyValue, args[0]) - - return super().call_operator(op, args, kwargs, meta) + node.replace_all_uses_with(input_node) + return True + return False @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopMulOpPass(ExportPass): +class RemoveNopMulOpPass(RemoveOrReplacePassInterface): """ If a mul op is multiplying two tensors with the same shape and one of those tensors is all zeros, return the zero tensor instead. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.mul.Tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mul.Tensor] - # Parse the args - (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + def maybe_remove_or_replace(self, node: Node) -> bool: + input1 = node.args[0] + input2 = node.args[1] + assert isinstance(input1, Node) + assert isinstance(input2, Node) # Check if both inputs have the same shape - if input1.to_tensor().shape != input2.to_tensor().shape: - return super().call_operator(op, args, kwargs, meta) + if input1.meta["val"].shape != input2.meta["val"].shape: + return False # Check if one of the inputs is a zero tensor - if input1.node.target == exir_ops.edge.aten.full.default: - if input1.node.args[1] == 0: - return input1 - elif input2.node.target == exir_ops.edge.aten.full.default: - if input2.node.args[1] == 0: - return input2 + if input1.target == exir_ops.edge.aten.full.default: + if input1.args[1] == 0: + node.replace_all_uses_with(input1) + return True + elif input2.target == exir_ops.edge.aten.full.default: + if input2.args[1] == 0: + node.replace_all_uses_with(input2) + return True - return super().call_operator(op, args, kwargs, meta) + return False @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopAddOpPass(ExportPass): +class RemoveNopAddOpPass(RemoveOrReplacePassInterface): """ If an add op is adding two tensors with the same shape and one of those tensors is all zeros, return the other tensor instead. """ - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.aten.add.Tensor: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.add.Tensor] - # Parse the args - (input1, input2) = cast(tuple[ProxyValue, ProxyValue], args) + def maybe_remove_or_replace(self, node: Node) -> bool: + input1 = node.args[0] + input2 = node.args[1] + assert isinstance(input1, Node) + assert isinstance(input2, Node) # Check if both inputs have the same shape - if input1.to_tensor().shape != input2.to_tensor().shape: - return super().call_operator(op, args, kwargs, meta) + if input1.meta["val"].shape != input2.meta["val"].shape: + return False # Check if one of the inputs is a zero tensor - if input1.node.target == exir_ops.edge.aten.full.default: - if input1.node.args[1] == 0: - return input2 - elif input2.node.target == exir_ops.edge.aten.full.default: - if input2.node.args[1] == 0: - return input1 + if input1.target == exir_ops.edge.aten.full.default: + if input1.args[1] == 0: + node.replace_all_uses_with(input2) + return True + elif input2.target == exir_ops.edge.aten.full.default: + if input2.args[1] == 0: + node.replace_all_uses_with(input1) + return True - return super().call_operator(op, args, kwargs, meta) + return False @register_cadence_pass(CadencePassAttribute(opt_level=2)) @@ -749,17 +736,17 @@ def get_squeeze_indices(self, view_node: Node) -> List[int]: return squeeze_indices - def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None: + def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> bool: if view_node in visited_view_nodes: - return + return False squeeze_indices = self.get_squeeze_indices(view_node) if not squeeze_indices: - return + return False # Only handle simple chains for now. if len(view_node.users) != 1: - return + return False node = next(iter(view_node.users)) # Traverse down from the node until finding another view op. @@ -767,9 +754,9 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None while node.target != exir_ops.edge.aten.view_copy.default: # Only handle simple chains for now if len(node.users) != 1: - return + return False if node.target not in self.intermediate_ops: - return + return False if node.target == exir_ops.edge.aten.slice_copy.Tensor: intermediate_slices.append(node) node = next(iter(node.users)) @@ -792,18 +779,22 @@ def handle_squeeze(self, view_node: Node, visited_view_nodes: Set[Node]) -> None # Skip the initial view node. input_node = cast(Node, get_arg(view_node, "input")) view_node.replace_all_uses_with(input_node) + return True def call(self, graph_module: torch.fx.GraphModule) -> PassResult: visited_view_nodes = set() + modified = False for view_node in graph_module.graph.find_nodes( op="call_function", target=exir_ops.edge.aten.view_copy.default, sort=True ): - self.handle_squeeze(view_node, visited_view_nodes) + modified |= self.handle_squeeze(view_node, visited_view_nodes) - graph_module.graph.eliminate_dead_code() - graph_module.recompile() + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) - return super().call(graph_module) + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -826,23 +817,27 @@ class RemoveBranchedQuantDequant(ExportPass): } def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.remove_branched( + modified = self.remove_branched( graph_module, self.quantize_op_packets, self.dequantize_op_packets ) - self.remove_branched( + modified |= self.remove_branched( graph_module, self.dequantize_op_packets, self.quantize_op_packets ) - graph_module.graph.eliminate_dead_code() - result = super().call(graph_module) - return result + if modified: + graph_module.graph.eliminate_dead_code() + result = super().call(graph_module) + return result + + return PassResult(graph_module, False) def remove_branched( self, graph_module: torch.fx.GraphModule, producer_pkts: set[EdgeOpOverloadPacket], consumer_pkts: set[EdgeOpOverloadPacket], - ) -> None: + ) -> bool: + modified = False for node in graph_module.graph.nodes: if ( node.op != "call_function" @@ -866,61 +861,62 @@ def remove_branched( continue user.replace_all_uses_with(node.args[0]) + modified = True + return modified -class RemoveCatFromSliceCopyPass(ExportPass): + +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveCatFromSliceCopyPass(RemoveOrReplacePassInterface): """ Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed to the slice_copy. """ - def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None: - for slice_copy_node in graph_module.graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor - ): - cat_node = cast(Node, get_arg(slice_copy_node, "input")) - slice_dim = cast(int, get_arg(slice_copy_node, "dim")) - start_idx = cast(int, get_arg(slice_copy_node, "start")) - end_idx = cast(int, get_arg(slice_copy_node, "end")) - step = cast(int, get_arg(slice_copy_node, "step")) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.slice_copy.Tensor] - if cat_node.target != exir_ops.edge.aten.cat.default or step != 1: - continue + def maybe_remove_or_replace(self, node: Node) -> bool: + cat_node = cast(Node, get_arg(node, "input")) + slice_dim = cast(int, get_arg(node, "dim")) + start_idx = cast(int, get_arg(node, "start")) + end_idx = cast(int, get_arg(node, "end")) + step = cast(int, get_arg(node, "step")) - # Make sure cat and slice happens on the same dimension. - cat_dim = cast(Node, get_arg(cat_node, "dim")) - if cat_dim != slice_dim: - continue + if cat_node.target != exir_ops.edge.aten.cat.default or step != 1: + return False - # Canonicalize slice indices. - cat_output_shape = cat_node.meta["val"].shape - if start_idx is None: - start_idx = 0 - elif start_idx < 0: - start_idx += cat_output_shape[cat_dim] - if end_idx is None or end_idx > cat_output_shape[cat_dim]: - end_idx = cat_output_shape[cat_dim] - elif end_idx < 0: - end_idx += cat_output_shape[cat_dim] - - offset = 0 - for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")): - cat_input_shape = cat_input_node.meta["val"].shape - - # Check if the slice range overlaps with the cat input range. - if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]: - slice_copy_node.replace_input_with(cat_node, cat_input_node) - set_arg(slice_copy_node, "start", start_idx - offset) - set_arg(slice_copy_node, "end", end_idx - offset) - break - - offset += cat_input_shape[cat_dim] + # Make sure cat and slice happens on the same dimension. + cat_dim = cast(int, get_arg(cat_node, "dim")) + if cat_dim != slice_dim: + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self._remove_unused_cat(graph_module) - graph_module.recompile() - graph_module.graph.eliminate_dead_code() - return super().call(graph_module) + # Canonicalize slice indices. + cat_output_shape = cat_node.meta["val"].shape + if start_idx is None: + start_idx = 0 + elif start_idx < 0: + start_idx += cat_output_shape[cat_dim] + if end_idx is None or end_idx > cat_output_shape[cat_dim]: + end_idx = cat_output_shape[cat_dim] + elif end_idx < 0: + end_idx += cat_output_shape[cat_dim] + + offset = 0 + for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")): + cat_input_shape = cat_input_node.meta["val"].shape + + # Check if the slice range overlaps with the cat input range. + if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]: + node.replace_input_with(cat_node, cat_input_node) + set_arg(node, "start", start_idx - offset) + set_arg(node, "end", end_idx - offset) + return True + + offset += cat_input_shape[cat_dim] + + return False class CommonRemovePasses: @@ -929,7 +925,6 @@ class CommonRemovePasses: RemoveAliasCopyOpPass, RemoveNopExpandOpPass, RemoveNopSliceOrViewOpPass, - RemoveNopSelectOpPass, RemoveToOpsPass, RemoveZeroSizedCatArgsPass, RemovePermutesAroundElementwiseOps, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index d430e95c470..c383adf4162 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -34,6 +34,7 @@ CadencePassAttribute, none_throws, register_cadence_pass, + RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet @@ -68,182 +69,195 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): +class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface): """ A where op with a logical_not and a boolean tensor can be replaced by a where op with flipped inputs and the initial boolean tensor. """ - def replace_logical_nop_where_with_where( - self, graph_module: torch.fx.GraphModule - ) -> None: - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in where nodes - if node.target != exir_ops.edge.aten.where.self: - continue + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.where.self] - # If the third arg is not a logical_not, bail. - if node.args[0].target != exir_ops.edge.aten.logical_not.default: - continue + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # If the first arg is not a logical_not, bail. + if not isinstance(node.args[0], torch.fx.Node): + return False - # Get the third arg node and its input - logical_not_node = node.args[0] - logical_not_input_node = logical_not_node.args[0] + logical_not_node = cast(torch.fx.Node, node.args[0]) + if logical_not_node.target != exir_ops.edge.aten.logical_not.default: + return False - # If the logical_not input is not a boolean tensor, bail. - if logical_not_input_node.meta["val"].dtype != torch.bool: - continue + # Get the first arg node and its input + if not isinstance(logical_not_node.args[0], torch.fx.Node): + return False - # Replace the where op with another one, flipping the inputs and using the boolean - # tensor from logical_not. - with graph.inserting_before(node): - linear_node = graph.call_function( - exir_ops.edge.aten.where.self, - args=(logical_not_node.args[0], node.args[2], node.args[1]), - ) - # Replace all the uses - node.replace_all_uses_with(linear_node) + logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0]) - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_node.meta["val"].dtype != torch.bool: + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.replace_logical_nop_where_with_where(graph_module) - result = super().call(graph_module) - return result + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_input_node, node.args[2], node.args[1]), + ) + new_node.meta = node.meta + # Replace all the uses + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSafeSoftmaxWithSoftmax(ExportPass): # keep +class ReplaceSafeSoftmaxWithSoftmax(RemoveOrReplacePassInterface): # keep """ Replace _safe_softmax with _softmax """ - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != torch.ops.aten._safe_softmax.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [torch.ops.aten._safe_softmax.default] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Add False for the half_to_float argument of softmax - softmax_args = list(args) + [False] + softmax_args = tuple(list(node.args) + [False]) - return super().call_operator( - torch.ops.aten._softmax.default, - tuple(softmax_args), - kwargs, - meta, - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + torch.ops.aten._softmax.default, + args=softmax_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2QuantWithCadenceQuantPass(ExportPass): +class ReplacePT2QuantWithCadenceQuantPass(RemoveOrReplacePassInterface): """ Replace the pt2 quantization ops with cadence quantization ops. We do not link kernels to the PT2 quantization ops, so we need to replace them with cadence ops at all optimization levels. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops - if op != ns.quantized_decomposed.quantize_per_tensor.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + ] - return super().call_operator( - ns.cadence.quantize_per_tensor.default, - args, - kwargs, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + ns.cadence.quantize_per_tensor.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplacePT2DequantWithCadenceDequantPass(ExportPass): +class ReplacePT2DequantWithCadenceDequantPass(RemoveOrReplacePassInterface): """ Replace the pt2 dequantization ops with cadence dequantization ops. We do not link kernels to the PT2 quantization ops, so we need to replace them with cadence ops at all optimization levels. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops - if op != ns.quantized_decomposed.dequantize_per_tensor.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + ] - return super().call_operator( - ns.cadence.dequantize_per_tensor.default, - args, - kwargs, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + ns.cadence.dequantize_per_tensor.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass): +class ReplaceSqueezeAndUnsqueezeWithViewPass(RemoveOrReplacePassInterface): """ When the shape is static, replace squeeze_copy and unsqueeze_copy ops with view_copy op """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket, - # which allows us to cover all overloads. - if get_edge_overload_packet(op) not in { - exir_ops.edge.aten.squeeze_copy, - exir_ops.edge.aten.unsqueeze_copy, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.squeeze_copy.default, + exir_ops.edge.aten.squeeze_copy.dim, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the output tensor shape - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape # Bail out if any dim is not an int (dynamic shape) for dim in list(out_shape): if not isinstance(dim, int): - return super().call_operator(op, args, kwargs, meta) + return False - # Return a view op with the new shape - view_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta - ) + # Replace with view op with the new shape + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(node.args[0], list(out_shape)), + ) + # Do not remove the metadata copy! + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceFunctionallyEquivalentOpTargets(ExportPass): +class ReplaceFunctionallyEquivalentOpTargets(RemoveOrReplacePassInterface): """ Replace an op with a functionally equivalent op by just switching the op target, but without incurring any change to the op args. """ - def call_operator(self, op, args, kwargs, meta): - if op not in functionally_equivalent_op_targets: - return super().call_operator(op, args, kwargs, meta) - return super().call_operator( - functionally_equivalent_op_targets[op], args, kwargs, meta - ) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(functionally_equivalent_op_targets.keys()) + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert isinstance(node.target, EdgeOpOverload) + target_op = functionally_equivalent_op_targets[node.target] + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + target_op, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + + # RemoveOrReplacePassInterface calls eliminate_dead_code, but this doesn't + # remove impure nodes (nodes which have side effects). Not sure if that is + # generally safe, so instead of modifying the interface, just erasing + # these nodes for this pass. + node.graph.erase_node(node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -435,14 +449,16 @@ class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass): def call_operator(self, op, args, kwargs, meta): op_packet = get_edge_overload_packet(op) if op_packet not in { - exir_ops.edge.cadence.convolution, + exir_ops.edge.cadence.conv1d, + exir_ops.edge.cadence.conv2d, + exir_ops.edge.cadence.conv3d, exir_ops.edge.cadence.transposed_convolution, }: return super().call_operator(op, args, kwargs, meta) is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution - expected_args = 9 if is_transposed else 8 - assert len(args) == expected_args + num_expected_args = 9 if is_transposed else 7 + assert len(args) == num_expected_args # Check if the bias is already concrete if args[2] is not None: return super().call_operator(op, args, kwargs, meta) @@ -667,20 +683,28 @@ def call_operator(self, op, args, kwargs, meta): output_padding, groups, ) = args - # Currently we only handle conversion to conv1d and conv2d, therefore + # Currently we only handle conversion to conv1d, conv2d, and conv3d, therefore # verify that the stride, padding, dilation, and output_padding have - # len <=2. + # len <=3. assert ( - len(stride) == len(padding) == len(dilation) == len(output_padding) == 1 - ) or ( - len(stride) == len(padding) == len(dilation) == len(output_padding) == 2 - ), "Can only map convolution to conv1d and conv2d at present" - - target = ( - exir_ops.edge.cadence.transposed_convolution.default - if transposed - else exir_ops.edge.cadence.convolution.default - ) + (len(stride) == len(padding) == len(dilation) == len(output_padding) == 1) + or ( + len(stride) == len(padding) == len(dilation) == len(output_padding) == 2 + ) + or ( + len(stride) == len(padding) == len(dilation) == len(output_padding) == 3 + ) + ), "Can only map convolution to conv1d, conv2d, and conv3d at present" + + # Determine if this is 1D, 2D, or 3D convolution based on parameter lengths + if transposed: + target = exir_ops.edge.cadence.transposed_convolution.default + elif len(stride) == 1: + target = exir_ops.edge.cadence.conv1d.default + elif len(stride) == 2: + target = exir_ops.edge.cadence.conv2d.default + else: # len(stride) == 3 + target = exir_ops.edge.cadence.conv3d.default if transposed: # Flip the height and width dimensions of weight, since we apply a @@ -739,7 +763,6 @@ def call_operator(self, op, args, kwargs, meta): padding, dilation, groups, - False, ) return super().call_operator(target, new_args, kwargs, meta) @@ -761,7 +784,9 @@ class ReplaceTrivialConvWithLinear(ExportPass): """ trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { - exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } @@ -778,7 +803,7 @@ def call_operator(self, op, args, kwargs, meta): op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) - assert (len(args) == 8 and not quantized_op) or ( + assert (len(args) == 7 and not quantized_op) or ( len(args) >= 12 and quantized_op ), "Inconsistent args for convolution" (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] @@ -933,7 +958,9 @@ def call_operator( meta: NodeMetadata, ) -> ProxyValue: if op not in { - exir_ops.edge.cadence.convolution.default, + exir_ops.edge.cadence.conv1d.default, + exir_ops.edge.cadence.conv2d.default, + exir_ops.edge.cadence.conv3d.default, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, }: return super().call_operator(op, args, kwargs, meta) @@ -944,11 +971,11 @@ def call_operator( # Already in NHWC layout. return super().call_operator(op, args, kwargs, meta) - new_op = ( - exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor - if quantized_op - else exir_ops.edge.cadence.convolution.default - ) + if quantized_op: + new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + else: + # Determine if 1D or 2D convolution based on op + new_op = op input_proxy = cast(ProxyValue, args[0]) weight_proxy = cast(ProxyValue, args[1]) @@ -1021,7 +1048,9 @@ class ReplaceConvWithIm2RowAndLinear(ExportPass): # A map from the convolution op to the linear op that it should # decompose to. conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = { - exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } @@ -1035,7 +1064,7 @@ def call_operator(self, op, args, kwargs, meta): op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) - assert (len(args) == 8 and not quantized_op) or ( + assert (len(args) == 7 and not quantized_op) or ( len(args) >= 12 and quantized_op ), "Inconsistent args for convolution" (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] @@ -1438,82 +1467,95 @@ def call_operator(self, op, args, kwargs, meta): @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceScalarTensorWithFullPass(ExportPass): +class ReplaceScalarTensorWithFullPass(RemoveOrReplacePassInterface): """ aten.scalar_tensor can be replaced by aten.full with a shape of [1]. scalar_tensor is not supported, so this is an opt_level=0 pass. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.scalar_tensor.default, + @property + def targets(self) -> list[EdgeOpOverload]: + return [ torch.ops.aten.scalar_tensor.default, - }: - return super().call_operator(op, args, kwargs, meta) + exir_ops.edge.aten.scalar_tensor.default, + ] - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - [1], - args[0], - ), - {"dtype": torch.float32}, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=( + [1], + node.args[0], + ), + kwargs={"dtype": torch.float32}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceFullLikeWithFullPass(ExportPass): +class ReplaceFullLikeWithFullPass(RemoveOrReplacePassInterface): """ aten.full_like can be replaced by aten.full with the shape of the arg tensor. full_like is not supported, so this is an opt_level=0 pass. """ - def call_operator(self, op, args, kwargs, meta): - if op not in { - exir_ops.edge.aten.full_like.default, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.full_like.default] - # Get the shape of the "like" tensor, and pass that in to the full op. - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - args[0].to_tensor().shape, - args[1], - ), - {}, - meta, - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + input_arg = node.args[0] + assert isinstance(input_arg, torch.fx.Node) + shape = input_arg.meta["val"].shape + fill_value = node.args[1] + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=(shape, fill_value), + kwargs={}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceInfArgInFullWithValuePass(ExportPass): +class ReplaceInfArgInFullWithValuePass(RemoveOrReplacePassInterface): """ aten.full allows "-inf" and "inf" as inputs. The profiler cannot handle that, so replace them with the maximum value of the type. """ - def call_operator(self, op, args, kwargs, meta): - if op not in { - exir_ops.edge.aten.full.default, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.full.default] - new_args = list(args) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - if args[1] == float("-inf"): + new_args = list(node.args) + fill_value = node.args[1] + if fill_value == float("-inf"): new_args[1] = torch.finfo(torch.float32).min - elif args[1] == float("inf"): + elif fill_value == float("inf"): new_args[1] = torch.finfo(torch.float32).max + else: + return False - return super().call_operator(op, tuple(new_args), kwargs, meta) + new_args = tuple(new_args) + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=new_args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) @@ -1713,26 +1755,6 @@ def call_operator( return super().call_operator(op, args, kwargs, meta) -@register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass): - """ - Replace the aten gelu op with an approximate arg with an approximate gelu op. - """ - - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.gelu.default, - }: - return super().call_operator(op, args, kwargs, meta) - return super().call_operator(op, args, kwargs, meta) - - # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py @register_cadence_pass(CadencePassAttribute(opt_level=2)) class ReplaceSplitWithSlicePass(ExportPass): @@ -2122,18 +2144,25 @@ class CommonReplacePasses: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass): +class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(RemoveOrReplacePassInterface): """ Replace aten linalg svd op with cadence custom op. """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten._linalg_svd.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten._linalg_svd.default] - return super().call_operator( - exir_ops.edge.cadence.linalg_svd.default, args, kwargs, meta - ) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.cadence.linalg_svd.default, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True # This class encapsulates all the functions that replace/switch one op in the @@ -2165,6 +2194,5 @@ class CadenceReplaceOpsInGraph: ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceWhereWithFullArgsWithWhereScalar, - ReplaceAtenApproxGeluWithApproxGeluPass, ReplaceMulTensorWithMulAndFullOpsPass, ] diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 5629ed518e5..24bbe7ee644 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1572,25 +1572,6 @@ def test_rope( [[[[5.0]]]], dtype=torch.float32 ), # expected: 1*1 + 4*1 = 5 ), - # Test case 2: Basic 2D convolution (NHWC format) - ( - "basic_2d_nhwc", - torch.tensor( - [[[[1.0], [2.0]], [[3.0], [4.0]]]], dtype=torch.float32 - ), # input: 1x2x2x1 (NHWC) - torch.tensor( - [[[[1.0], [0.0]], [[0.0], [1.0]]]], dtype=torch.float32 - ), # weight: 1x2x2x1 (NHWC format) - torch.tensor([0.0], dtype=torch.float32), # bias - (1, 1), # stride - (0, 0), # padding - (1, 1), # dilation - 1, # groups - True, # channel_last - torch.tensor( - [[[[5.0]]]], dtype=torch.float32 - ), # expected: 1*1 + 4*1 = 5 - ), # Test case 3: 2D convolution with stride=2 ( "conv2d_stride2", @@ -1709,23 +1690,6 @@ def test_rope( [[[3.0, 5.0, 7.0]]], dtype=torch.float32 ), # expected: [1+2, 2+3, 3+4] ), - # Test case 8: 1D convolution (NLC format) - ( - "conv1d_nlc", - torch.tensor( - [[[1.0], [2.0], [3.0], [4.0]]], dtype=torch.float32 - ), # input: 1x4x1 (NLC) - torch.tensor( - [[[1.0], [1.0]]], dtype=torch.float32 - ), # weight: 1x2x1 (NLC) - torch.tensor([0.0], dtype=torch.float32), # bias - (1, 1), # stride - (0, 0), # padding - (1, 1), # dilation - 1, # groups - True, # channel_last - torch.tensor([[[3.0], [5.0], [7.0]]], dtype=torch.float32), - ), # Test case 9: Multi-channel input and output ( "multi_channel", @@ -1796,19 +1760,31 @@ def test_convolution( padding: tuple[int, int], dilation: tuple[int, int], groups: int, - channel_last: bool, + channel_last: bool, # Keep for backward compatibility with test data, but won't use expected_output: torch.Tensor, ) -> None: - output = torch.ops.cadence.convolution( - input_tensor, - weight, - bias, - stride, - padding, - dilation, - groups, - channel_last, - ) + # Determine if 1D or 2D based on input shape + is_conv1d = len(input_tensor.shape) == 3 + if is_conv1d: + output = torch.ops.cadence.conv1d( + input_tensor, + weight, + bias, + (stride[0],), + (padding[0],), + (dilation[0],), + groups, + ) + else: + output = torch.ops.cadence.conv2d( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + ) # Verify output properties self.assertEqual( @@ -1830,7 +1806,6 @@ def test_convolution( @expand( [ - # Basic 2D transposed convolution with stride=1 (current test case - corrected name) ( "basic_2d_stride1", torch.tensor( @@ -1851,33 +1826,6 @@ def test_convolution( dtype=torch.float32, ), ), - # 2D transposed convolution with channel_last=True (NHWC format) - ( - "channel_last_nhwc", - torch.tensor( - [[[[1.0], [2.0]], [[3.0], [4.0]]]], dtype=torch.float32 - ), # input: 1x2x2x1 (NHWC) - torch.tensor( - [[[[1.0], [1.0]], [[1.0], [1.0]]]], dtype=torch.float32 - ), # weight: 1x2x2x1 (NHWC) - torch.tensor([0.0], dtype=torch.float32), # bias - (1, 1), # stride - (0, 0), # padding - (1, 1), # dilation - 1, # groups - (0, 0), # output_padding - True, # channel_last=True - torch.tensor( - [ - [ - [[1.0], [3.0], [2.0]], - [[4.0], [10.0], [6.0]], - [[3.0], [7.0], [4.0]], - ] - ], - dtype=torch.float32, - ), - ), # 2D transposed convolution with non-zero bias ( "with_bias", @@ -1899,26 +1847,6 @@ def test_convolution( dtype=torch.float32, ), ), - # 1D transposed convolution (3D tensor, NLC format) - ( - "conv1d_nlc", - torch.tensor( - [[[1.0], [2.0], [3.0]]], dtype=torch.float32 - ), # input: 1x3x1 (NLC) - torch.tensor( - [[[1.0], [0.5]]], dtype=torch.float32 - ), # weight: 1x2x1 (NLC) - torch.tensor([0.0], dtype=torch.float32), # bias - (2, 0), # stride - (0, 0), # padding - (1, 1), # dilation - 1, # groups - (0, 0), # output_padding - True, # channel_last=True - torch.tensor( - [[[1.0], [0.5], [2.0], [1.0], [3.0], [1.5]]], dtype=torch.float32 - ), - ), ] ) def test_transposed_convolution( diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 483d737f97d..158ec73cf27 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -196,13 +196,13 @@ def test_remove_zero_sized_constant_pad_nd( ) builder.output([pad]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveZeroSizedConstantPadNd()(original) - ).graph_module + pass_result = cast(PassResult, RemoveZeroSizedConstantPadNd()(original)) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default), 0, ) + self.assertTrue(pass_result.modified) def test_remove_expand(self) -> None: builder = GraphBuilder() @@ -228,12 +228,12 @@ def test_remove_zero_arg_cat(self) -> None: ) builder.output([concat]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveZeroSizedCatArgsPass()(original) - ).graph_module + pass_result = cast(PassResult, RemoveZeroSizedCatArgsPass()(original)) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 ) + self.assertTrue(pass_result.modified) def test_remove_clone(self) -> None: builder = GraphBuilder() @@ -611,7 +611,9 @@ def test_remove_squeeze_view_before_elemwise_ops(self) -> None: original = deepcopy(model) p = RemoveSqueezeViewBeforeElementwiseOps() - transformed = cast(PassResult, p(model)).graph_module + pass_result = cast(PassResult, p(model)) + self.assertTrue(pass_result.modified) + transformed = pass_result.graph_module # First view should be eliminated and second view should be trivial. views = transformed.graph.find_nodes( @@ -872,9 +874,9 @@ def test_remove_dequant_on_branch(self) -> None: ) builder.output([x1_output, y1_output]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveBranchedQuantDequant()(original) - ).graph_module + pass_result = cast(PassResult, RemoveBranchedQuantDequant()(original)) + self.assertTrue(pass_result.modified) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node( graph_after_passes, @@ -904,9 +906,9 @@ def test_remove_cat_from_slice_copy(self) -> None: ) builder.output([output]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveCatFromSliceCopyPass()(original) - ).graph_module + pass_result = cast(PassResult, RemoveCatFromSliceCopyPass()(original)) + self.assertTrue(pass_result.modified) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 ) @@ -922,9 +924,9 @@ def test_keep_cat_from_slice_copy(self) -> None: ) builder.output([output]) original = builder.get_graph_module() - graph_after_passes = cast( - PassResult, RemoveCatFromSliceCopyPass()(original) - ).graph_module + pass_result = cast(PassResult, RemoveCatFromSliceCopyPass()(original)) + self.assertFalse(pass_result.modified) + graph_after_passes = pass_result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 1 ) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 73964c6c4c4..573489f40b9 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -23,7 +23,6 @@ MakeSliceAndCatDimOutermostPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAddMMWithLinearPass, - ReplaceAtenApproxGeluWithApproxGeluPass, ReplaceAtenConvolutionWithCadenceConvolutionPass, ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceConstantPadNdWithSlicePass, @@ -34,6 +33,7 @@ ReplaceFunctionallyEquivalentOpTargets, ReplaceIm2RowWithViewPass, ReplaceLinearWithFullyConnectedOpPass, + ReplaceLogicalNotBooleanWhereWithWherePass, ReplaceMatmulWithTransposedMatmulPass, ReplaceMMWithAddMMPass, ReplaceMulTensorWithMulAndFullOpsPass, @@ -329,7 +329,9 @@ def test_replace_functionally_equivalent_op_targets_relu( args=(x,), ) p = ReplaceFunctionallyEquivalentOpTargets() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.relu.default), @@ -456,8 +458,9 @@ def test_replace_aten_conv_with_cadence_conv( count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), 0, ) + # This is a 1D convolution (using [stride], [padding], [dilation]) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default), 1, ) self.assertEqual( @@ -545,10 +548,6 @@ def test_replace_aten_transposed_conv_with_cadence_transposed_conv( count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), 0, ) - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), - 0, - ) self.assertEqual( count_node( graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default @@ -645,19 +644,17 @@ def test_replace_transposed_conv_with_linear( 0, ) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default) + + count_node(graph_after_passes, exir_ops.edge.cadence.conv2d.default), 0, ) @expand( [ - [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False], + [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False], # # depthwise - [(1, 8, 33), 8, 16, 3, 1, 0, 1, True, False, False], - [(1, 8, 33), 8, 16, 3, 2, 4, 3, True, False, False], - # channel last (uses a permute op before calling conv1d) - [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True], - [(1, 33, 8), 8, 16, 3, 2, 4, 3, True, False, True], + [(1, 8, 33), 8, 16, 3, 1, 0, 1, True, False], + [(1, 8, 33), 8, 16, 3, 2, 4, 3, True, False], ] ) @torch.no_grad() @@ -672,7 +669,6 @@ def test_replace_convolution_optional_args_with_concrete_args( dilation: int = 1, depthwise: bool = False, bias_enabled: bool = True, - channel_last: bool = False, ) -> None: groups = in_channels if depthwise else 1 builder = GraphBuilder() @@ -688,13 +684,8 @@ def test_replace_convolution_optional_args_with_concrete_args( if bias_enabled else None ) - if channel_last: - x = builder.call_operator( - op=exir_ops.edge.aten.permute_copy.default, - args=(x, [0, 2, 1]), - ) convolution = builder.call_operator( - op=exir_ops.edge.cadence.convolution.default, + op=exir_ops.edge.cadence.conv1d.default, args=( x, weights, @@ -703,14 +694,8 @@ def test_replace_convolution_optional_args_with_concrete_args( [padding], [dilation], groups, - False, ), ) - if channel_last: - convolution = builder.call_operator( - op=exir_ops.edge.aten.permute_copy.default, - args=(convolution, [0, 2, 1]), - ) builder.output([convolution]) original_gm = builder.get_graph_module() p = ReplaceConvolutionOptionalArgsWithConcreteArgsPass() @@ -720,7 +705,7 @@ def test_replace_convolution_optional_args_with_concrete_args( 1, ) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default), 1, ) @@ -813,7 +798,9 @@ def test_replace_masked_scalar_tensor_with_full( builder.output([aten_where_self]) original_gm = builder.get_graph_module() p = ReplaceScalarTensorWithFullPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1, @@ -837,7 +824,9 @@ def test_replace_scalar_tensor_with_full( args=(0.123,), ) p = ReplaceScalarTensorWithFullPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1, @@ -983,7 +972,12 @@ def test_replace_squeeze_with_view( args=(x,), ) p = ReplaceSqueezeAndUnsqueezeWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), @@ -1018,7 +1012,12 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None: args=(x, dim), ) p = ReplaceSqueezeAndUnsqueezeWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), @@ -1029,6 +1028,28 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None: 0, ) + @torch.no_grad() + def test_replace_squeeze_and_unsqueeze_with_view_no_modification(self) -> None: + """Negative test: pass doesn't modify graphs without squeeze/unsqueeze ops.""" + x = torch.randn(2, 3, 4) + original_gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.view_copy.default, + args=(x, [2, 12]), + ) + p = ReplaceSqueezeAndUnsqueezeWithViewPass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass did NOT modify the graph + self.assertFalse(result.modified) + graph_after_passes = result.graph_module + + # Verify the original view_copy operation is still there + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), + 1, + ) + @torch.no_grad() def test_replace_conv1d_with_linear(self) -> None: x = torch.randn(1, 96, 7) @@ -1036,21 +1057,16 @@ def test_replace_conv1d_with_linear(self) -> None: bias = torch.randn(192) original_gm = single_op_builder( placeholders=(x, weights, bias), - op=exir_ops.edge.cadence.convolution.default, - args=(x, weights, bias, [1], [0], [1], 1, False), + op=exir_ops.edge.cadence.conv1d.default, + args=(x, weights, bias, [1], [0], [1], 1), ) - # First, replace the aten convolution with a cadence.convolution op - p1 = ReplaceAtenConvolutionWithCadenceConvolutionPass() - temp_graph = cast(PassResult, p1(original_gm)).graph_module - # temp_graph = p1(original_gm).graph_module - self.assertIsNotNone(temp_graph) p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module + graph_after_passes = cast(PassResult, p2(original_gm)).graph_module # Assert that conv1d is trivially converted to linear self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + count_node(graph_after_passes, exir_ops.edge.cadence.conv1d.default), 0 ) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0 @@ -1070,20 +1086,16 @@ def test_replace_conv2d_with_linear(self) -> None: bias = torch.randn(192) original_gm = single_op_builder( placeholders=(x, weights, bias), - op=exir_ops.edge.cadence.convolution.default, - args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1, False), + op=exir_ops.edge.cadence.conv2d.default, + args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1), ) - # First, replace the aten convolution with a cadence.convolution op - p1 = ReplaceAtenConvolutionWithCadenceConvolutionPass() - temp_graph = cast(PassResult, p1(original_gm)).graph_module - self.assertIsNotNone(temp_graph) p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module + graph_after_passes = cast(PassResult, p2(original_gm)).graph_module # Assert that conv2d is trivially converted to linear self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + count_node(graph_after_passes, exir_ops.edge.cadence.conv2d.default), 0 ) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0 @@ -1103,15 +1115,15 @@ def test_replace_conv2d_with_im2row_and_linear(self) -> None: bias = torch.randn(192) original_gm = single_op_builder( placeholders=(x, weights, bias), - op=exir_ops.edge.cadence.convolution.default, - args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1, False), + op=exir_ops.edge.cadence.conv2d.default, + args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1), ) p = ReplaceConvWithIm2RowAndLinear() graph_after_passes = cast(PassResult, p(original_gm)).graph_module # Assert that the convolution is converted to im2row + linear self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0 + count_node(graph_after_passes, exir_ops.edge.cadence.conv2d.default), 0 ) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.cadence.im2row.per_tensor), 1 @@ -1339,28 +1351,6 @@ def test_replace_aten_where_with_cadence_broadcast( 1, ) - def test_no_replace_aten_gelu_with_approximate_gelu(self) -> None: - inputs = torch.randn(2, 1, 64) - - gm = single_op_builder( - placeholders=(inputs,), - op=exir_ops.edge.aten.gelu.default, - args=(inputs,), - ) - gm = ExportPass().call(gm).graph_module - - p = ReplaceAtenApproxGeluWithApproxGeluPass() - graph_after_passes = p.call(gm).graph_module - - # Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument - self.assertEqual( - count_node( - graph_after_passes, - exir_ops.edge.aten.gelu.default, - ), - 1, - ) - def test_replace_split_with_sizes_with_slice(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 16, 8, 4)) @@ -1547,62 +1537,10 @@ def create_conv1d_graphmodule( args = args + (channels_last,) return single_op_builder( placeholders=(x, w, b), - op=exir_ops.edge.cadence.convolution.default, + op=exir_ops.edge.cadence.conv1d.default, args=args, ) - def test_conv1d_default_channel_last(self) -> None: - # Create a graph with a single convolution node. - # Check if graph module is valid by running exportpass on it. - gm = self.create_conv1d_graphmodule() - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), - # Two transposes are added, one for the input and one for the output. - 3, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. - self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - - def test_conv1d_no_transpose_if_already_channel_last(self) -> None: - gm = self.create_conv1d_graphmodule(channels_last=True) - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), - 0, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. - self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - def create_convolution_graph_module( self, channels_last: Optional[bool] = None ) -> torch.fx.GraphModule: @@ -1624,62 +1562,10 @@ def create_convolution_graph_module( args = args + (channels_last,) return single_op_builder( placeholders=(x, w, b), - op=exir_ops.edge.cadence.convolution.default, + op=exir_ops.edge.cadence.conv2d.default, args=args, ) - def test_convolution_default_channel_last(self) -> None: - # Create a graph with a single convolution node. - # Check if graph module is valid by running exportpass on it. - gm = self.create_convolution_graph_module() - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), - # Three permutes are added, two for the input/weights and one for the output. - 3, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. - self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - - def test_no_transpose_if_already_channel_last(self) -> None: - gm = self.create_convolution_graph_module(channels_last=True) - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) - - # Apply replacement pass. - p = ReplaceConvWithChannelLastConvPass() - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default), - 1, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), - 0, - ) - for node in gm_after_replacement.graph.nodes: - if node.target != exir_ops.edge.cadence.convolution.default: - continue - # Check that the channel_last argument is set to True. - self.assertEqual(len(node.args), 8, f"{node=}") - self.assertTrue(node.args[7]) - def create_quantized_convolution_graph_module( self, channels_last: Optional[bool] = None ) -> tuple[tuple[torch.Tensor, ...], torch.fx.GraphModule]: @@ -2142,7 +2028,9 @@ def test_replace_aten_linalg_svd_with_cadence_linalg_svd( ) p = ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module # Assert that the aten linalg_svd op was replaced with cadence linalg_svd op self.assertEqual( @@ -2198,3 +2086,114 @@ def test_replace_quantized_embedding( ), 1, ) + + +class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase): + """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass.""" + + def test_replace_where_with_logical_not_boolean(self) -> None: + """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x).""" + # Setup: Create a graph with where(logical_not(bool_cond), x, y) + builder = GraphBuilder() + bool_cond_ = torch.randn(4, 8) > 0 + x_ = torch.randn(4, 8) + y_ = torch.randn(4, 8) + + bool_cond = builder.placeholder("bool_cond", bool_cond_) + x = builder.placeholder("x", x_) + y = builder.placeholder("y", y_) + + # Create logical_not node + logical_not = builder.call_operator( + op=exir_ops.edge.aten.logical_not.default, + args=(bool_cond,), + ) + + # Create where node using logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(logical_not, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Make a copy of the original graph before applying the pass + original_gm_copy = copy.deepcopy(original_gm) + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify logical_not is removed (dead code elimination) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), + 0, + ) + + # Assert: Verify where node still exists + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + # Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped) + where_nodes = list( + graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + ) + ) + for node in where_nodes: + # First arg should be the original bool_cond (not the logical_not) + self.assertEqual(node.args[0].name, "bool_cond") + # Second and third args should be swapped (y, x instead of x, y) + self.assertEqual(node.args[1].name, "y") + self.assertEqual(node.args[2].name, "x") + + # Assert: Verify outputs match exactly by running both graphs + validate( + original_gm_copy, + graph_after_passes, + (bool_cond_, x_, y_), + "ReplaceLogicalNotBooleanWhereWithWherePass", + ) + + def test_no_replacement_without_logical_not(self) -> None: + """Test that the pass does NOT apply when there's no logical_not.""" + # Setup: Create a graph with where(bool_cond, x, y) without logical_not + builder = GraphBuilder() + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) + x = builder.placeholder("x", torch.randn(4, 8)) + y = builder.placeholder("y", torch.randn(4, 8)) + + # Create where node directly without logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(bool_cond, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass did NOT modify the graph + self.assertFalse(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify where node still exists unchanged + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + for node in graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + ): + self.assertEqual(node.args[0].name, "bool_cond") + self.assertEqual(node.args[1].name, "x") + self.assertEqual(node.args[2].name, "y") diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp index b78cc33890b..7ef7bdde3b1 100644 --- a/backends/cadence/fusion_g3/operators/op_add.cpp +++ b/backends/cadence/fusion_g3/operators/op_add.cpp @@ -162,7 +162,7 @@ Tensor& add_out( float alpha_val; torch::executor::native::utils::extract_scalar(alpha, &alpha_val); - if ((a.numel() == 1) && (alpha_val == 1.0)) { + if ((a.numel() == 1) && (alpha_val == 1.0f)) { XT_KERNEL_CHECK( ctx, out, diff --git a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp index a97f9beb0c7..504a00fcaee 100644 --- a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp +++ b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp @@ -123,7 +123,7 @@ Tensor& slice_copy_Tensor_out( InvalidArgument, out); - torch::executor::compute_slice(in, dim, start, length, step, out); + torch::executor::compute_slice(ctx, in, dim, start, length, step, out); } return out; diff --git a/backends/cadence/hifi/kernels/kernels.cpp b/backends/cadence/hifi/kernels/kernels.cpp index 237c605443f..07f0ac960b1 100644 --- a/backends/cadence/hifi/kernels/kernels.cpp +++ b/backends/cadence/hifi/kernels/kernels.cpp @@ -39,8 +39,8 @@ void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) { template __attribute__((always_inline)) T quantize(const float x, float scale, int32_t zero_point) { - constexpr float min_val = std::numeric_limits::min(); - constexpr float max_val = std::numeric_limits::max(); + constexpr float min_val = static_cast(std::numeric_limits::min()); + constexpr float max_val = static_cast(std::numeric_limits::max()); float tmp = roundf(x * scale + zero_point); return std::max(std::min(tmp, max_val), min_val); } @@ -56,8 +56,8 @@ void quantize( xtfloatx2 scale_vec = (xtfloatx2)scale; xtfloatx2 zero_vec = XT_FLOAT_SX2(zero_point, 0); - constexpr float min_val = std::numeric_limits::min(); - constexpr float max_val = std::numeric_limits::max(); + constexpr float min_val = static_cast(std::numeric_limits::min()); + constexpr float max_val = static_cast(std::numeric_limits::max()); const xtfloatx2* __restrict__ p0 = (const xtfloatx2* __restrict__)x; ae_valign va0 = XT_LASX2PP(p0); diff --git a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp index d3dcce1d5f4..80c02a79e93 100644 --- a/backends/cadence/hifi/operators/op_quantized_relu_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp @@ -6,17 +6,18 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include -using executorch::aten::ScalarType; -using executorch::aten::Tensor; -using torch::executor::KernelRuntimeContext; - namespace impl { namespace HiFi { namespace native { +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + void quantized_relu_per_tensor_out( KernelRuntimeContext& ctx, const Tensor& input, @@ -34,7 +35,10 @@ void quantized_relu_per_tensor_out( const uint8_t* p_in = input.const_data_ptr(); uint8_t* p_out = output.mutable_data_ptr(); - WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u( + XT_KERNEL_CHECK( + ctx, + , + xa_nn_vec_relu_asym8u_asym8u, p_out, p_in, _in_zero_point, @@ -45,15 +49,16 @@ void quantized_relu_per_tensor_out( 255, input.numel()); - ET_CHECK_MSG(ret_val == 0, "An internal error occured"); - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { - const int8_t _in_zero_point = static_cast(in_zero_point); - const int8_t _out_zero_point = static_cast(out_zero_point); + const int _in_zero_point = static_cast(in_zero_point); + const int _out_zero_point = static_cast(out_zero_point); const int8_t* p_in = input.const_data_ptr(); int8_t* p_out = output.mutable_data_ptr(); - WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s( + XT_KERNEL_CHECK( + ctx, + , + xa_nn_vec_relu_asym8s_asym8s, p_out, p_in, _in_zero_point, @@ -64,8 +69,6 @@ void quantized_relu_per_tensor_out( 127, input.numel()); - ET_CHECK_MSG(ret_val == 0, "An internal error occured"); - } else { ET_CHECK_MSG( false, diff --git a/backends/cadence/hifi/operators/op_slice_copy.cpp b/backends/cadence/hifi/operators/op_slice_copy.cpp index 014eaa6698b..ff447461d6e 100644 --- a/backends/cadence/hifi/operators/op_slice_copy.cpp +++ b/backends/cadence/hifi/operators/op_slice_copy.cpp @@ -64,7 +64,7 @@ Tensor& slice_copy_Tensor_out( InvalidArgument, out); - compute_slice(in, dim, start, length, step, out); + compute_slice(ctx, in, dim, start, length, step, out); return out; } diff --git a/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp b/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp index dc9a197a504..a599d73ccc8 100644 --- a/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp +++ b/backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp @@ -57,14 +57,14 @@ class HiFiQuantizedReluTest : public OperatorTest { TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) { TensorFactory tf_chars; + TensorFactory tf_ints; const std::vector sizes{2, 3, 5, 6}; Tensor quantized_input = tf_chars.full(sizes, -128); Tensor quantized_output = tf_chars.full(sizes, 100); Tensor in_zero_point = tf_chars.full({1}, 127); int64_t out_zero_point = -128; - Tensor out_multiplier = - TensorFactory().full({1}, 1077952640); - Tensor out_shift = TensorFactory().full({1}, 5); + Tensor out_multiplier = tf_ints.full({1}, 1077952640); + Tensor out_shift = tf_ints.full({1}, 5); quantized_relu_out( quantized_input, @@ -80,14 +80,14 @@ TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) { TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) { TensorFactory tf_chars; + TensorFactory tf_ints; const std::vector sizes{56}; Tensor quantized_input = tf_chars.full(sizes, -128); Tensor quantized_output = tf_chars.full(sizes, 100); Tensor in_zero_point = tf_chars.full({1}, 127); int64_t out_zero_point = -128; - Tensor out_multiplier = - TensorFactory().full({1}, 1077952640); - Tensor out_shift = TensorFactory().full({1}, 5); + Tensor out_multiplier = tf_ints.full({1}, 1077952640); + Tensor out_shift = tf_ints.full({1}, 5); quantized_relu_out( quantized_input, diff --git a/backends/cadence/hifi/third-party/nnlib/targets.bzl b/backends/cadence/hifi/third-party/nnlib/targets.bzl index a63a4dd3954..2ad9d6568ac 100644 --- a/backends/cadence/hifi/third-party/nnlib/targets.bzl +++ b/backends/cadence/hifi/third-party/nnlib/targets.bzl @@ -13,6 +13,10 @@ def define_common_targets(): "@EXECUTORCH_CLIENTS", ], compatible_with = ["ovr_config//cpu:xtensa"], + compiler_flags = [ + "-Wno-pointer-sign", + "-Wno-incompatible-pointer-types-discards-qualifiers", + ], deps = [ "fbsource//third-party/nnlib-hifi4/xa_nnlib:libxa_nnlib", ], diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c index 2f1d2071777..68a51223cde 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c @@ -21,7 +21,7 @@ ******************************************************************************/ #include -#include "../include/NatureDSP_Signal_math.h" +#include "NatureDSP_Signal_math.h" #include "NatureDSP_types.h" #include "xa_nn_common.h" diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c index aa81d695784..5fb69113ee7 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c @@ -20,7 +20,7 @@ ******************************************************************************/ -#include "../include/NatureDSP_Signal_math.h" +#include "NatureDSP_Signal_math.h" #include "NatureDSP_types.h" #include "xa_nn_common.h" diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c index e7e83846484..840a027f7a7 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c @@ -117,6 +117,7 @@ WORD32 xa_nn_elm_where_f32xf32_f32(FLOAT32 * __restrict__ p_out, XT_MOVF_S(a, a2, s); XT_SSI(a, (xtfloat *)out, 0); } + return 0; } static void internal_elm_where_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out, diff --git a/backends/cadence/runtime/runtime.py b/backends/cadence/runtime/runtime.py index a7d35fbd0c9..3a139e415ea 100644 --- a/backends/cadence/runtime/runtime.py +++ b/backends/cadence/runtime/runtime.py @@ -45,7 +45,7 @@ def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[ op_names |= get_op_names( deserialize_pte_binary( program.backend_delegate_data[delegate.processed.index].data - ) + ).program ) return op_names diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py index e49cf412c19..b5c5683ab5d 100644 --- a/backends/cadence/utils/facto_util.py +++ b/backends/cadence/utils/facto_util.py @@ -15,6 +15,7 @@ import torch from facto.inputgen.argtuple.gen import ArgumentTupleGenerator from facto.inputgen.specs.model import ConstraintProducer as cp +from facto.inputgen.utils.random_manager import seeded_random_manager as rm from facto.inputgen.variable.type import ScalarDtype from facto.specdb.db import SpecDictDB @@ -26,6 +27,33 @@ _shape_cache: dict[str, list[int]] = {} +def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int, ...]]: + """ + Generate valid permutations using only positive dimension indices. + This is required for Cadence/Xtensa kernels that don't support negative indexing. + + Args: + tensor: Input tensor to generate permutations for + length: Number of dimensions in the permutation (must equal tensor.dim()) + + Returns: + Set of valid permutation tuples containing only positive indices [0, rank-1] + """ + if length > tensor.dim(): + return set() + + n = tensor.dim() + pool = list(range(n)) + + # Generate multiple valid permutations (only positive indices) + permutations: set[tuple[int, ...]] = set() + for _ in range(3): # Generate 3 different permutations for diversity + perm = tuple(rm.get_random().sample(pool, length)) + permutations.add(perm) + + return permutations + + def apply_tensor_contraints(op_name: str, index: int) -> list[object]: # Constraint to limit tensor size to < 4000 bytes with fully randomized shapes import random @@ -161,47 +189,37 @@ def random_size_constraint(deps: object, r: int, d: int) -> int: if index == 0: # condition tensor_constraints = [ cp.Dtype.In(lambda deps: [torch.bool]), - cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), - cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Value.Ge(lambda deps, dtype, struct: 0), + cp.Value.Le(lambda deps, dtype, struct: 1), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), max_size_constraint, ] elif index == 1: # input tensor(a) tensor_constraints = [ - cp.Dtype.In( - lambda deps: [ - torch.int8, - torch.int16, - torch.uint8, - torch.uint16, - torch.int32, - torch.float32, - ] - ), + cp.Dtype.In(lambda deps: [torch.float32]), cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.In( + lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d) + ), max_size_constraint, ] else: # input tensor(b) tensor_constraints = [ - cp.Dtype.In( - lambda deps: [ - torch.int8, - torch.int16, - torch.uint8, - torch.uint16, - torch.int32, - torch.float32, - ] - ), + cp.Dtype.In(lambda deps: [torch.float32]), cp.Dtype.Eq(lambda deps: deps[1].dtype), cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.In( + lambda deps, r, d: fn.broadcast_with( + fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d + ) + ), max_size_constraint, ] case "embedding.default": @@ -248,6 +266,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int: tensor_constraints.extend( [ cp.Dtype.In(lambda deps: [torch.float32, torch.int32]), + # Avoid NaN/Inf values that expose clamp NaN handling bugs + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), ] ) case "rsqrt.default": @@ -323,12 +344,15 @@ def random_size_constraint(deps: object, r: int, d: int) -> int: ] ) case "constant_pad_nd.default": - tensor_constraints.extend( - [ - cp.Dtype.In(lambda deps: [torch.float32]), - cp.Size.Le(lambda deps, r, d: 2**2), - ] - ) + tensor_constraints = [ + cp.Dtype.In(lambda deps: [torch.float32]), + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Rank.Ge(lambda deps: 1), + cp.Rank.Le(lambda deps: 2), # Reduced from 3 to 2 (max 2D tensors) + cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.Le(lambda deps, r, d: 3), # Max dimension size of 3 + ] case "avg_pool2d.default": tensor_constraints.extend( [ @@ -344,14 +368,25 @@ def random_size_constraint(deps: object, r: int, d: int) -> int: ] ) case "div.Tensor": - tensor_constraints.extend( - [ - cp.Value.Ne(lambda deps, dtype, struct: 0), - cp.Value.Le(lambda deps, dtype, struct: 2**3), - cp.Size.Le(lambda deps, r, d: 2**3), - cp.Rank.Le(lambda deps: 2**2), - ] - ) + if index == 1: # Only apply zero-prevention to divisor + tensor_constraints.extend( + [ + cp.Value.Ne( + lambda deps, dtype, struct: 0 + ), # Prevent division by zero + cp.Value.Le(lambda deps, dtype, struct: 2**3), + cp.Size.Le(lambda deps, r, d: 2**3), + cp.Rank.Le(lambda deps: 2**2), + ] + ) + else: + tensor_constraints.extend( + [ + cp.Value.Le(lambda deps, dtype, struct: 2**3), + cp.Size.Le(lambda deps, r, d: 2**3), + cp.Rank.Le(lambda deps: 2**2), + ] + ) case "pow.Tensor_Scalar": tensor_constraints.extend( [ @@ -373,6 +408,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int: cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]), cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Value.Ne( + lambda deps, dtype, struct: 0 + ), # Prevent division by zero cp.Rank.Ge(lambda deps: 1), cp.Rank.Eq(lambda deps: deps[0].dim()), cp.Size.Eq(lambda deps, r, d: fn.safe_size(deps[0], d)), @@ -389,6 +427,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int: cp.Value.Le(lambda deps, dtype, struct: 2**2), cp.Size.Le(lambda deps, r, d: 2**3), ] + case "leaky_relu.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32]), + ] + ) case "_softmax.default": tensor_constraints.extend( [ @@ -396,6 +440,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int: cp.Size.Le(lambda deps, r, d: 2**2), ] ) + case "flip.default": + tensor_constraints.extend( + [ + cp.Dtype.In(lambda deps: [torch.float32]), + ] + ) case _: pass return tensor_constraints @@ -409,6 +459,7 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]: | "mul.Scalar" | "div.Scalar" | "constant_pad_nd.default" + | "clamp.default" ): return [ScalarDtype.int] case "full.default": @@ -436,11 +487,44 @@ def facto_testcase_gen( # noqa: C901 cp.Size.Le(lambda deps, r, d: 2**2), ] ) - if in_spec.name == "max_val": # hardtanh + # Special handling for clamp.default to ensure min < max with sufficient gap (at least 2) and never None + if op_name == "clamp.default": + if in_spec.name == "min": + # min must always be provided (not None) and bounded, leave room for max + spec.inspec[index].constraints.extend( + [ + cp.Optional.Eq(lambda deps: False), # Never None + cp.Value.Ge(lambda deps, dtype: -(2**4)), + cp.Value.Le( + lambda deps, dtype: 2**4 - 2 + ), # Leave room for max (at least 2 units) + ] + ) + elif in_spec.name == "max": + # max must always be provided (not None), be >= min + 2 (sufficient gap), and bounded + spec.inspec[index].deps = [0, 1] # deps on input tensor and min + spec.inspec[index].constraints.extend( + [ + cp.Optional.Eq(lambda deps: False), # Never None + cp.Value.Ge( + lambda deps, dtype: deps[1] + 2 + ), # max >= min + 2 (sufficient gap) + cp.Value.Le(lambda deps, dtype: 2**4), + ] + ) + elif in_spec.name == "max_val": # hardtanh spec.inspec[index].deps = [0, 1] spec.inspec[index].constraints.extend( [cp.Value.Ge(lambda deps, _: deps[1])] ) + elif in_spec.name == "negative_slope" and op_name == "leaky_relu.default": + # For leaky_relu, negative_slope should be in typical range (0, 1] + spec.inspec[index].constraints.extend( + [ + cp.Value.Gt(lambda deps, dtype: 0), + cp.Value.Le(lambda deps, dtype: 1.0), + ] + ) else: spec.inspec[index].constraints.extend( [ @@ -465,12 +549,32 @@ def facto_testcase_gen( # noqa: C901 apply_tensor_contraints(op_name, index) ) elif in_spec.type.is_dim_list(): - spec.inspec[index].constraints.extend( - [ - cp.Length.Ge(lambda deps: 1), - cp.Optional.Eq(lambda deps: False), - ] - ) + # Special handling for permute_copy.default to ensure valid permutation + if op_name == "permute_copy.default": + spec.inspec[index].constraints.extend( + [ + cp.Length.Ge(lambda deps: 1), + cp.Length.Eq( + lambda deps: deps[0].dim() + ), # Must be a complete permutation + cp.Optional.Eq(lambda deps: False), + # Generate valid permutations using only positive indices + # Cadence/Xtensa hardware kernels do not support negative dimension indices + cp.Value.Gen( + lambda deps, length: ( + _positive_valid_dim_list(deps[0], length), + fn.invalid_dim_list(deps[0], length), + ) + ), + ] + ) + else: + spec.inspec[index].constraints.extend( + [ + cp.Length.Ge(lambda deps: 1), + cp.Optional.Eq(lambda deps: False), + ] + ) elif in_spec.type.is_bool(): spec.inspec[index].constraints.extend( [ diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index a728584e49c..ac330d4b015 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -56,7 +56,12 @@ set(_cortex_m_kernels__srcs ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp ) # Generate C++ bindings to register kernels into Executorch diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index c7e6cc8a389..716b53cdcf8 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -18,10 +18,19 @@ #include #include +#include +#include + +extern "C" { +#include "arm_nn_types.h" +} + using Tensor = torch::executor::Tensor; using ScalarType = executorch::aten::ScalarType; using Scalar = torch::executor::Scalar; using Error = executorch::runtime::Error; +using IntArrayRef = executorch::aten::ArrayRef; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; // From arm_nn_math_types.h #define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL)) @@ -33,7 +42,8 @@ inline void validate_cmsis_nn_tensor_requirements( const Tensor& input2, Tensor& output, ScalarType expected_dtype = ScalarType::Char, - bool require_channels_last = false) { + bool require_channels_last = false, + bool require_same_sizes = true) { // Basic dtype validation ET_CHECK_MSG( input1.scalar_type() == expected_dtype, @@ -50,12 +60,14 @@ inline void validate_cmsis_nn_tensor_requirements( "Output dtype must be %hhd, got %hhd", expected_dtype, output.scalar_type()); - ET_CHECK_MSG( - input1.sizes() == input2.sizes(), - "Input1 and Input2 must have the same sizes"); - ET_CHECK_MSG( - output.sizes() == input1.sizes(), - "Output must have the same sizes as inputs"); + if (require_same_sizes) { + ET_CHECK_MSG( + input1.sizes() == input2.sizes(), + "Input1 and Input2 must have the same sizes"); + ET_CHECK_MSG( + output.sizes() == input1.sizes(), + "Output must have the same sizes as inputs"); + } // Dim order consistency ET_CHECK_MSG( diff --git a/backends/cortex_m/ops/op_maximum.cpp b/backends/cortex_m/ops/op_maximum.cpp new file mode 100644 index 00000000000..71a907f12ea --- /dev/null +++ b/backends/cortex_m/ops/op_maximum.cpp @@ -0,0 +1,102 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& maximum_out( + KernelRuntimeContext& context, + const Tensor& input1, + const Tensor& input2, + Tensor& out) { + validate_cmsis_nn_tensor_requirements( + input1, + input2, + out, + ScalarType::Char, + /*require_channels_last=*/false, + /*require_same_sizes=*/false); + + auto resize_error = resize_to_broadcast_target_size(input1, input2, out); + if (resize_error != Error::Ok) { + ET_LOG(Error, "maximum_out: broadcast shape mismatch between inputs"); + context.fail(resize_error); + return out; + } + + const int8_t* input1_data = input1.const_data_ptr(); + const int8_t* input2_data = input2.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + + // Create CMSIS-NN dims directly from tensor sizes + const auto input1_rank = input1.dim(); + const auto input1_sizes = input1.sizes(); + const cmsis_nn_dims input1_dims{ + static_cast( + input1_rank >= 4 ? input1_sizes[input1_rank - 4] : 1), + static_cast( + input1_rank >= 3 ? input1_sizes[input1_rank - 3] : 1), + static_cast( + input1_rank >= 2 ? input1_sizes[input1_rank - 2] : 1), + static_cast( + input1_rank >= 1 ? input1_sizes[input1_rank - 1] : 1)}; + + const auto input2_rank = input2.dim(); + const auto input2_sizes = input2.sizes(); + const cmsis_nn_dims input2_dims{ + static_cast( + input2_rank >= 4 ? input2_sizes[input2_rank - 4] : 1), + static_cast( + input2_rank >= 3 ? input2_sizes[input2_rank - 3] : 1), + static_cast( + input2_rank >= 2 ? input2_sizes[input2_rank - 2] : 1), + static_cast( + input2_rank >= 1 ? input2_sizes[input2_rank - 1] : 1)}; + + const auto output_rank = out.dim(); + const auto output_sizes = out.sizes(); + const cmsis_nn_dims output_dims{ + static_cast( + output_rank >= 4 ? output_sizes[output_rank - 4] : 1), + static_cast( + output_rank >= 3 ? output_sizes[output_rank - 3] : 1), + static_cast( + output_rank >= 2 ? output_sizes[output_rank - 2] : 1), + static_cast( + output_rank >= 1 ? output_sizes[output_rank - 1] : 1)}; + + const arm_cmsis_nn_status status = arm_maximum_s8( + /* ctx */ nullptr, + input1_data, + &input1_dims, + input2_data, + &input2_dims, + output_data, + &output_dims); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "maximum_out: arm_maximum_s8 failed with status [%d]", + static_cast(status)); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_minimum.cpp b/backends/cortex_m/ops/op_minimum.cpp new file mode 100644 index 00000000000..f220aa2664b --- /dev/null +++ b/backends/cortex_m/ops/op_minimum.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& minimum_out( + KernelRuntimeContext& context, + const Tensor& input1, + const Tensor& input2, + Tensor& out) { + validate_cmsis_nn_tensor_requirements( + input1, + input2, + out, + ScalarType::Char, + /*require_channels_last=*/false, + /*require_same_sizes=*/false); + + auto resize_error = resize_to_broadcast_target_size(input1, input2, out); + if (resize_error != Error::Ok) { + ET_LOG(Error, "minimum_out: broadcast shape mismatch between inputs"); + context.fail(resize_error); + return out; + } + + const int8_t* input1_data = input1.const_data_ptr(); + const int8_t* input2_data = input2.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + + // Create CMSIS-NN dims directly from tensor sizes + const auto input1_rank = input1.dim(); + const auto input1_sizes = input1.sizes(); + const cmsis_nn_dims input1_dims{ + static_cast( + input1_rank >= 4 ? input1_sizes[input1_rank - 4] : 1), + static_cast( + input1_rank >= 3 ? input1_sizes[input1_rank - 3] : 1), + static_cast( + input1_rank >= 2 ? input1_sizes[input1_rank - 2] : 1), + static_cast( + input1_rank >= 1 ? input1_sizes[input1_rank - 1] : 1)}; + + const auto input2_rank = input2.dim(); + const auto input2_sizes = input2.sizes(); + const cmsis_nn_dims input2_dims{ + static_cast( + input2_rank >= 4 ? input2_sizes[input2_rank - 4] : 1), + static_cast( + input2_rank >= 3 ? input2_sizes[input2_rank - 3] : 1), + static_cast( + input2_rank >= 2 ? input2_sizes[input2_rank - 2] : 1), + static_cast( + input2_rank >= 1 ? input2_sizes[input2_rank - 1] : 1)}; + + const auto output_rank = out.dim(); + const auto output_sizes = out.sizes(); + const cmsis_nn_dims output_dims{ + static_cast( + output_rank >= 4 ? output_sizes[output_rank - 4] : 1), + static_cast( + output_rank >= 3 ? output_sizes[output_rank - 3] : 1), + static_cast( + output_rank >= 2 ? output_sizes[output_rank - 2] : 1), + static_cast( + output_rank >= 1 ? output_sizes[output_rank - 1] : 1)}; + + const arm_cmsis_nn_status status = arm_minimum_s8( + /* ctx */ nullptr, + input1_data, + &input1_dims, + input2_data, + &input2_dims, + output_data, + &output_dims); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "minimum_out: arm_minimum_s8 failed with status [%d]", + static_cast(status)); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp index 30be108ffcb..ddc4b4bb869 100644 --- a/backends/cortex_m/ops/op_quantized_add.cpp +++ b/backends/cortex_m/ops/op_quantized_add.cpp @@ -78,6 +78,15 @@ Tensor& quantized_add_out( output_mult, output_shift_val); + // Note 1: The CMSIS-NN kernel implementation uses offsets which are always + // added to the data, whereas zero_points are subtracted when dequantizing + // (for the inputs) and added when quantizing (for the output). Hence the + // negative signs required here. + + // Note 2: It is not possible to perform the same rewrite as for mul for + // addition. To preserve precision when rescaling the inputs, they are first + // upscaled as much as possible, Hence the left_shift parameter required here. + // Call CMSIS-NN kernel with precomputed parameters arm_cmsis_nn_status status = arm_elementwise_add_s8( input1_int8.const_data_ptr(), diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp new file mode 100644 index 00000000000..ad14af98865 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -0,0 +1,236 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +namespace { +constexpr int64_t kConvDim = 4; + +bool validate_conv2d_arguments( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const Tensor& output, + const IntArrayRef& stride, + const IntArrayRef& padding, + const IntArrayRef& dilation, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts) { + if (input.dim() != kConvDim || weight.dim() != kConvDim || + output.dim() != kConvDim) { + ET_LOG(Error, "quantized_conv2d_out: tensors must be 4-D"); + context.fail(Error::InvalidArgument); + return false; + } + + // Check for channels_last dim_order (NHWC: 0, 2, 3, 1) + // Skip check if channels == 1, as dim_order is ambiguous in that case + constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = { + 0, 2, 3, 1}; + executorch::aten::ArrayRef + channels_last_order(kChannelsLastDimOrder, 4); + + if (input.size(1) > 1 && input.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_conv2d_out: input must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (output.size(1) > 1 && output.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_conv2d_out: output must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (input.scalar_type() != ScalarType::Char || + output.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "quantized_conv2d_out: input and output must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (weight.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "quantized_conv2d_out: weight must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) { + ET_LOG(Error, "quantized_conv2d_out: bias must be int32 if provided"); + context.fail(Error::InvalidArgument); + return false; + } + + if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) { + ET_LOG( + Error, + "quantized_conv2d_out: stride, padding, and dilation must have length 2"); + context.fail(Error::InvalidArgument); + return false; + } + + const int64_t out_channels = output.size(1); + if (requantize_multipliers.size(0) != out_channels || + requantize_shifts.size(0) != out_channels) { + ET_LOG( + Error, + "quantized_conv2d_out: per-channel params must match output channels (%zd)", + out_channels); + context.fail(Error::InvalidArgument); + return false; + } + + return true; +} +} // namespace + +Tensor& quantized_conv2d_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const int64_t input_offset, + const int64_t output_offset, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts, + const int64_t activation_min, + const int64_t activation_max, + Tensor& out) { + if (!validate_conv2d_arguments( + context, + input, + weight, + bias, + out, + stride, + padding, + dilation, + requantize_multipliers, + requantize_shifts)) { + return out; + } + + const int32_t batch = static_cast(input.size(0)); + const int32_t input_channels = static_cast(input.size(1)); + const int32_t input_height = static_cast(input.size(2)); + const int32_t input_width = static_cast(input.size(3)); + + const int32_t kernel_output_channels = static_cast(weight.size(0)); + const int32_t kernel_height = static_cast(weight.size(1)); + const int32_t kernel_width = static_cast(weight.size(2)); + const int32_t kernel_input_channels = static_cast(weight.size(3)); + + const int32_t output_channels = static_cast(out.size(1)); + const int32_t output_height = static_cast(out.size(2)); + const int32_t output_width = static_cast(out.size(3)); + + const int32_t input_offset_val = static_cast(input_offset); + const int32_t output_offset_val = static_cast(output_offset); + const int32_t activation_min_val = static_cast(activation_min); + const int32_t activation_max_val = static_cast(activation_max); + + const cmsis_nn_dims input_dims{ + batch, input_height, input_width, input_channels}; + const cmsis_nn_dims filter_dims{ + kernel_output_channels, + kernel_height, + kernel_width, + kernel_input_channels}; + const cmsis_nn_dims output_dims{ + batch, output_height, output_width, output_channels}; + const cmsis_nn_dims bias_dims{1, 1, 1, output_channels}; + const cmsis_nn_dims upscale_dims{1, 1, 1, 1}; + + cmsis_nn_conv_params conv_params; + conv_params.input_offset = input_offset_val; + conv_params.output_offset = output_offset_val; + conv_params.stride.h = static_cast(stride[0]); + conv_params.stride.w = static_cast(stride[1]); + conv_params.padding.h = static_cast(padding[0]); + conv_params.padding.w = static_cast(padding[1]); + conv_params.dilation.h = static_cast(dilation[0]); + conv_params.dilation.w = static_cast(dilation[1]); + conv_params.activation.min = activation_min_val; + conv_params.activation.max = activation_max_val; + + cmsis_nn_per_channel_quant_params quant_params; + quant_params.multiplier = requantize_multipliers.data_ptr(); + quant_params.shift = requantize_shifts.data_ptr(); + + const int8_t* input_data = input.const_data_ptr(); + const int8_t* weight_data = weight.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + const int32_t* bias_data = + bias.has_value() ? bias.value().const_data_ptr() : nullptr; + + cmsis_nn_context cmsis_context; + cmsis_context.buf = nullptr; + cmsis_context.size = 0; + + const size_t buffer_bytes = static_cast( + arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims)); + if (buffer_bytes > 0) { + auto buffer_or_error = + context.allocate_temp(buffer_bytes, alignof(int16_t)); + if (!buffer_or_error.ok()) { + if (buffer_or_error.error() != Error::NotFound) { + ET_LOG( + Error, + "quantized_conv2d_out: failed to allocate scratch buffer (%d)", + static_cast(buffer_or_error.error())); + context.fail(buffer_or_error.error()); + return out; + } + } else { + cmsis_context.buf = buffer_or_error.get(); + cmsis_context.size = buffer_bytes; + } + } + + const arm_cmsis_nn_status status = arm_convolve_wrapper_s8( + &cmsis_context, + &conv_params, + &quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_conv2d_out: arm_convolve_s8 failed with status %d", + status); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index d1ccb6d0d45..015fa805134 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -1,12 +1,12 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "cmsis_scratch_buffer_context.h" #include "cortex_m_ops_common.h" extern "C" { @@ -20,151 +20,90 @@ using KernelRuntimeContext = torch::executor::KernelRuntimeContext; Tensor& quantized_linear_out( KernelRuntimeContext& context, const Tensor& input, - const Scalar& input_zero_point, - const Scalar& input_multiplier, - const Scalar& input_shift, const Tensor& weights, - const Tensor& weight_zero_point, - const Tensor& weight_multiplier, - const Tensor& weight_shift, const torch::executor::optional& bias, - const Tensor& bias_multiplier, - const Tensor& bias_shift, - const Tensor& scratch_buffer, - const Scalar& output_zero_point, - const Scalar& in_features, - const Scalar& out_features, + const torch::executor::optional& kernel_sum, + const Scalar& input_offset, + const Scalar& filter_offset, + const Scalar& output_offset, + const IntArrayRef requantize_multipliers, + const IntArrayRef requantize_shifts, + const Scalar& activation_max, + const Scalar& activation_min, Tensor& out) { ET_LOG(Info, "quantized_linear_out: called"); - validate_cmsis_nn_tensor_requirements(input, weights, out); - - ET_CHECK_MSG( - scratch_buffer.scalar_type() == ScalarType::Char, - "Scratch buffer must be int8"); - - const int32_t batch_size = input.size(0); - const int32_t in_feat = static_cast(in_features.to()); - const int32_t out_feat = static_cast(out_features.to()); - const int32_t input_zp = static_cast(input_zero_point.to()); - const int32_t output_zp = - static_cast(output_zero_point.to()); - const bool is_per_channel = (weight_zero_point.numel() > 1); const int8_t* input_data = input.const_data_ptr(); const int8_t* weight_data = weights.const_data_ptr(); const int32_t* bias_data = bias.has_value() ? bias.value().const_data_ptr() : nullptr; + int32_t* kernel_sum_data = + kernel_sum.has_value() ? kernel_sum.value().data_ptr() : nullptr; int8_t* output_data = out.mutable_data_ptr(); - const int32_t* weight_zp_data = weight_zero_point.const_data_ptr(); - const int32_t* weight_mult_data = weight_multiplier.const_data_ptr(); - const int32_t* weight_shift_data = weight_shift.const_data_ptr(); - - if (!validate_per_channel_quant_params( - weight_mult_data, weight_shift_data, out_feat)) { - context.fail(Error::InvalidArgument); - return out; - } - - // Initialize scratch buffer context (validates early) - CMSISScratchBufferContext scratch_ctx( - const_cast(scratch_buffer), weights, weight_zero_point, bias); - scratch_ctx.compute_kernel_sums_if_needed(); - cmsis_nn_context ctx = scratch_ctx.get_cmsis_ctx(); + cmsis_nn_context ctx; + ctx.size = 0; // Not used in CMSIS-NN + ctx.buf = kernel_sum_data; // Setup CMSIS-NN parameters cmsis_nn_fc_params fc_params; - fc_params.input_offset = -input_zp; - fc_params.output_offset = output_zp; - fc_params.activation.min = std::numeric_limits::min(); - fc_params.activation.max = std::numeric_limits::max(); - - cmsis_nn_dims input_dims = {1, 1, 1, in_feat}; + fc_params.input_offset = static_cast(input_offset.to()); + fc_params.filter_offset = static_cast(filter_offset.to()); + fc_params.output_offset = static_cast(output_offset.to()); + fc_params.activation.min = static_cast(activation_min.to()); + fc_params.activation.max = static_cast(activation_max.to()); + + cmsis_nn_per_tensor_quant_params per_tensor_quant_params; + per_tensor_quant_params.multiplier = + static_cast(requantize_multipliers.at(0)); + per_tensor_quant_params.shift = static_cast(requantize_shifts.at(0)); + + auto in_feat = input.size(input.dim() - 1); + auto out_feat = out.size(out.dim() - 1); + auto batches = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + batches *= input.size(i); + } + ET_LOG( + Info, + "in features: %d, out_features: %d, batches: %d, kernel_sum_size: %d", + in_feat, + out_feat, + batches, + kernel_sum.has_value() ? kernel_sum.value().numel() : 0); + ET_LOG( + Info, + "kernel_sum[0]: %d, kernel_sum[1]: %d", + kernel_sum_data != nullptr ? kernel_sum_data[0] : -1, + kernel_sum_data != nullptr ? kernel_sum_data[1] : -1); + cmsis_nn_dims input_dims = {batches, 1, 1, in_feat}; cmsis_nn_dims filter_dims = {in_feat, 1, 1, out_feat}; cmsis_nn_dims bias_dims = {1, 1, 1, out_feat}; - cmsis_nn_dims output_dims = {1, 1, 1, out_feat}; - - arm_cmsis_nn_status status; - for (int32_t b = 0; b < batch_size; b++) { - const int8_t* batch_input = input_data + b * in_feat; - int8_t* batch_output = output_data + b * out_feat; - - ET_CHECK_MSG( - batch_input != nullptr && weight_data != nullptr, - "Null input pointers"); - ET_CHECK_MSG(in_feat > 0 && out_feat > 0, "Invalid dimensions"); - - if (is_per_channel) { - cmsis_nn_per_channel_quant_params per_channel_quant_params; - per_channel_quant_params.multiplier = - const_cast(weight_mult_data); - per_channel_quant_params.shift = const_cast(weight_shift_data); - - status = arm_fully_connected_per_channel_s8( - &ctx, - &fc_params, - &per_channel_quant_params, - &input_dims, - batch_input, - &filter_dims, - weight_data, - &bias_dims, - bias_data, - &output_dims, - batch_output); - } else { - fc_params.filter_offset = -weight_zp_data[0]; - cmsis_nn_per_tensor_quant_params per_tensor_quant_params; - per_tensor_quant_params.multiplier = weight_mult_data[0]; - per_tensor_quant_params.shift = weight_shift_data[0]; - - status = arm_fully_connected_s8( - &ctx, - &fc_params, - &per_tensor_quant_params, - &input_dims, - batch_input, - &filter_dims, - weight_data, - &bias_dims, - bias_data, - &output_dims, - batch_output); - } - - if (status != ARM_CMSIS_NN_SUCCESS) { - ET_LOG( - Error, - "quantized_linear_out: CMSIS-NN failed with status [%d]", - status); - context.fail(Error::Internal); - return out; - } + cmsis_nn_dims output_dims = {batches, 1, 1, out_feat}; + + arm_cmsis_nn_status status = arm_fully_connected_s8( + &ctx, + &fc_params, + &per_tensor_quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_linear_out: CMSIS-NN failed with status [%d]", + status); + context.fail(Error::Internal); + return out; } - return out; -} -// Functional variant (stub, not used at runtime) -Tensor quantized_linear( - KernelRuntimeContext& context, - const Tensor& input, - const Scalar& input_zero_point, - const Scalar& input_multiplier, - const Scalar& input_shift, - const Tensor& weights, - const Tensor& weight_zero_point, - const Tensor& weight_multiplier, - const Tensor& weight_shift, - const torch::executor::optional& bias, - const Tensor& bias_multiplier, - const Tensor& bias_shift, - const Tensor& scratch_buffer, - const Scalar& output_zero_point, - const Scalar& in_features, - const Scalar& out_features) { - ET_LOG(Info, "quantized_linear: called"); - assert(false); - return const_cast(input); + return out; } } // namespace native diff --git a/backends/cortex_m/ops/op_quantized_mul.cpp b/backends/cortex_m/ops/op_quantized_mul.cpp new file mode 100644 index 00000000000..28af8406f87 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_mul.cpp @@ -0,0 +1,102 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { +namespace { + +constexpr int32_t kInt8ActivationMin = std::numeric_limits::min(); +constexpr int32_t kInt8ActivationMax = std::numeric_limits::max(); + +} // namespace + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +Tensor& quantized_mul_out( + KernelRuntimeContext& context, + const Tensor& input1_int8, + const Scalar& input1_zero_point, + const Tensor& input2_int8, + const Scalar& input2_zero_point, + const Scalar& output_zero_point, + const Scalar& output_multiplier, + const Scalar& output_shift, + Tensor& out) { + // Validate tensor types and quantization parameters + validate_cmsis_nn_tensor_requirements(input1_int8, input2_int8, out); + + const Scalar kIdentityMultiplier(/*value=*/1); + const Scalar kZeroShift(/*value=*/0); + validate_quantization_params( + input1_zero_point, + kIdentityMultiplier, + kZeroShift, + input2_zero_point, + kIdentityMultiplier, + kZeroShift, + output_zero_point, + output_multiplier, + output_shift, + out); + + // Extract quantization parameters + const int32_t zp1 = extractScalarToInt32(input1_zero_point); + const int32_t zp2 = extractScalarToInt32(input2_zero_point); + const int32_t out_zp = extractScalarToInt32(output_zero_point); + const int32_t output_mult = extractScalarToInt32(output_multiplier); + const int32_t output_shift_val = extractScalarToInt32(output_shift); + + // Note 1: The CMSIS-NN kernel implementation uses offsets which are always + // added to the data, whereas zero_points are subtracted when dequantizing + // (for the inputs) and added when quantizing (for the output). Hence the + // negative signs required here. + + // Note 2: The following rewrite is used + // yq = y / scale_out + zp_out + // y = x_1*x_2 + // x_i = scale_in_i * (xq_i - xq_i), i = 1, 2 + // ==> + // yq = (xq_1 - zp_in1) * (xq_2 - zp_in_2) * effective_scale + zp_out + // where + // effective_scale = (scale_in1 * scale_in2 / scale_out) + // Hence no input quantization params required here. + + // Call CMSIS-NN elementwise multiply kernel + arm_cmsis_nn_status status = arm_elementwise_mul_s8( + input1_int8.const_data_ptr(), + input2_int8.const_data_ptr(), + -static_cast(zp1), + -static_cast(zp2), + out.mutable_data_ptr(), + static_cast(out_zp), + output_mult, + output_shift_val, + kInt8ActivationMin, + kInt8ActivationMax, + static_cast(out.numel())); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_mul_out: arm_elementwise_mul_s8 failed with status [%d]", + status); + context.fail(Error::Internal); + return out; + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/op_transpose.cpp b/backends/cortex_m/ops/op_transpose.cpp new file mode 100644 index 00000000000..7befafc3791 --- /dev/null +++ b/backends/cortex_m/ops/op_transpose.cpp @@ -0,0 +1,124 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +#include +#include +#include + +// Include CMSIS-NN headers with C linkage +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +namespace { + +constexpr size_t kMaxSupportedDims = 4; + +} // namespace + +Tensor& transpose_out( + KernelRuntimeContext& context, + const Tensor& input, + const IntArrayRef perm, + Tensor& out) { + if (input.scalar_type() != ScalarType::Char || + out.scalar_type() != ScalarType::Char) { + ET_LOG( + Error, + "transpose_out: only int8 tensors are supported (input=%d, out=%d)", + static_cast(input.scalar_type()), + static_cast(out.scalar_type())); + context.fail(Error::InvalidArgument); + return out; + } + + const size_t rank = input.dim(); + if (rank == 0 || rank > kMaxSupportedDims) { + ET_LOG( + Error, + "transpose_out: expected tensor rank in [1, %zu], got %zu", + kMaxSupportedDims, + rank); + context.fail(Error::InvalidArgument); + return out; + } + + if (perm.size() != static_cast(rank)) { + ET_LOG( + Error, + "transpose_out: permutation length %zd does not match tensor rank %zu", + perm.size(), + rank); + context.fail(Error::InvalidArgument); + return out; + } + + std::array input_dims_arr{1, 1, 1, 1}; + std::array output_dims_arr{1, 1, 1, 1}; + for (size_t i = 0; i < rank; ++i) { + const auto in_size = input.size(i); + const auto out_size = out.size(i); + if (in_size > std::numeric_limits::max() || + out_size > std::numeric_limits::max()) { + ET_LOG( + Error, + "transpose_out: dimension size exceeds int32_t range (input=%lld, output=%lld)", + static_cast(in_size), + static_cast(out_size)); + context.fail(Error::InvalidArgument); + return out; + } + input_dims_arr[i] = static_cast(in_size); + output_dims_arr[i] = static_cast(out_size); + } + + cmsis_nn_dims input_dims = { + input_dims_arr[0], + input_dims_arr[1], + input_dims_arr[2], + input_dims_arr[3]}; + cmsis_nn_dims output_dims = { + output_dims_arr[0], + output_dims_arr[1], + output_dims_arr[2], + output_dims_arr[3]}; + + std::array perm_buffer{0, 1, 2, 3}; + for (size_t i = 0; i < rank; ++i) { + perm_buffer[i] = static_cast(perm[i]); + } + + const cmsis_nn_transpose_params transpose_params{ + static_cast(rank), perm_buffer.data()}; + + const int8_t* input_data = input.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + + const arm_cmsis_nn_status status = arm_transpose_s8( + input_data, output_data, &input_dims, &output_dims, &transpose_params); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "transpose_out: arm_transpose_s8 failed with status [%d]", + static_cast(status)); + context.fail(Error::Internal); + return out; + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 286f938ccc9..fe175ca9783 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -5,7 +5,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from math import prod +from typing import Sequence + import torch +import torch.nn.functional as F from executorch.backends.cortex_m.passes.passes_utils import ( requantize_cmsis, SHIFT_INT8, @@ -136,6 +140,10 @@ def quantized_add_meta( output_multiplier: int, output_shift: int, ) -> torch.Tensor: + assert self.shape == other.shape, ( + "Cortex-M quantized_mul: broadcasting is not yet supported — " + f"got self.shape={self.shape}, other.shape={other.shape}" + ) broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) @@ -154,6 +162,10 @@ def quantized_add_impl( output_multiplier: int, output_shift: int, ) -> torch.Tensor: + assert self.shape == other.shape, ( + "Cortex-M quantized_mul: broadcasting is not yet supported — " + f"got self.shape={self.shape}, other.shape={other.shape}" + ) self_shifted = (self.to(torch.int32) - self_zero_point) << SHIFT_INT8 self_fp = requantize_cmsis(self_shifted, self_multiplier, self_shift) @@ -167,213 +179,394 @@ def quantized_add_impl( # =================================================================== -# QUANTIZED LINEAR OPERATION DEFINITION +# QUANTIZED MUL OPERATION DEFINITION # =================================================================== +lib.define( + "quantized_mul(" + "Tensor self, Scalar self_zero_point, " + "Tensor other, Scalar other_zero_point, " + "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor" +) +lib.define( + "quantized_mul.out(" + "Tensor self, Scalar self_zero_point, " + "Tensor other, Scalar other_zero_point, " + "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, " + "*, Tensor(a!) out) -> Tensor(a!)" +) -def _check_per_tensor_or_per_channel(param: torch.Tensor, out_channels: int, name: str): - assert param.numel() in [ - 1, - out_channels, - ], f"{name} must be per-tensor (1) or per-channel ({out_channels}), got {param.numel()}" +@register_fake("cortex_m::quantized_mul") +def quantized_mul_meta( + self: torch.Tensor, + self_zero_point: int, + other: torch.Tensor, + other_zero_point: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, +) -> torch.Tensor: + # Broadcast to output shape + assert self.shape == other.shape, ( + "Cortex-M quantized_mul: broadcasting is not yet supported — " + f"got self.shape={self.shape}, other.shape={other.shape}" + ) + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=torch.int8, device=self.device) +@impl(lib, "quantized_mul", "CompositeExplicitAutograd") +def quantized_mul_impl( + self: torch.Tensor, + self_zero_point: int, + other: torch.Tensor, + other_zero_point: int, + output_zero_point: int, + output_multiplier: int, + output_shift: int, +) -> torch.Tensor: + # CMSIS-NN kernel multiplies raw int8 tensors (after zero-point offset) and + # only uses the output multiplier/shift for rescaling. Mirror that here to + # keep the composite implementation numerically aligned with the backend. + assert self.shape == other.shape, ( + "Cortex-M quantized_mul: broadcasting is not yet supported — " + f"got self.shape={self.shape}, other.shape={other.shape}" + ) + self_int = self.to(torch.int32) - self_zero_point + other_int = other.to(torch.int32) - other_zero_point + result_fp = self_int * other_int + result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift) + result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8) + return result + + +# =================================================================== +# MINIMUM/MAXIMUM OPERATION DEFINITIONS +# =================================================================== +lib.define("minimum(Tensor self, Tensor other) -> Tensor") +lib.define("minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)") + + +@register_fake("cortex_m::minimum") +def minimum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + assert self.dtype == other.dtype, ( + "Cortex-M minimum: dtype mismatch — " + f"got self.dtype={self.dtype}, other.dtype={other.dtype}" + ) + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device) + + +@impl(lib, "minimum", "CompositeExplicitAutograd") +def minimum_impl(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.minimum(self, other) + + +lib.define("maximum(Tensor self, Tensor other) -> Tensor") +lib.define("maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)") + + +@register_fake("cortex_m::maximum") +def maximum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + assert self.dtype == other.dtype, ( + "Cortex-M maximum: dtype mismatch — " + f"got self.dtype={self.dtype}, other.dtype={other.dtype}" + ) + broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape) + return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device) + + +@impl(lib, "maximum", "CompositeExplicitAutograd") +def maximum_impl(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.maximum(self, other) + + +# =================================================================== +# QUANTIZED LINEAR OPERATION DEFINITION +# =================================================================== + lib.define( "quantized_linear.out(" - "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor input, " "Tensor weights, " - "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " - "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " - "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, " - "*, Tensor(a!) out) -> Tensor(a!)" + "Tensor? bias, " + "Tensor? kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min, " + "*, Tensor(a!) out" + ") -> Tensor(a!)" ) # Define functional variant (non-out version) lib.define( "quantized_linear(" - "Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, " + "Tensor input, " "Tensor weights, " - "Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, " - "Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, " - "Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features" + "Tensor? bias, " + "Tensor? kernel_sum, " + "Scalar input_offset, " + "Scalar filter_offset, " + "Scalar output_offset, " + "int[] requantize_multipliers, " + "int[] requantize_shifts, " + "Scalar activation_max, " + "Scalar activation_min" ") -> Tensor" ) -# Fake meta function for shape inference (out variant) -@register_fake("cortex_m::quantized_linear.out") -def quantized_linear_out_meta( +# Fake meta function for shape inference (functional variant) +@register_fake("cortex_m::quantized_linear") +def quantized_linear_meta( + input, + weights, + bias, + kernel_sum, + input_offset, + filter_offset, + output_offset, + requantize_multipliers, + requantize_shifts, + activation_max, + activation_min, +) -> torch.Tensor: + + shape = (*input.shape[:-1], weights.shape[0]) + return torch.empty(shape, dtype=input.dtype, device=input.device) + + +# Functional variant implementation +@impl(lib, "quantized_linear", "CompositeExplicitAutograd") +def quantized_linear_impl( input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, - out: torch.Tensor, + kernel_sum: torch.Tensor, + input_offset: int, + filter_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_max: int, + activation_min: int, ) -> torch.Tensor: - # Validate dimensions - batch_size = input.shape[0] - out_channels = weights.shape[0] + """ + Functional variant - creates output tensor and calls out variant + """ - # Validate weight quantization parameters dimensions - _check_per_tensor_or_per_channel( - weight_zero_point, out_channels, "weight_zero_point" - ) - _check_per_tensor_or_per_channel( - weight_multiplier, out_channels, "weight_multiplier" - ) - _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") + # Leaving both implementations for debugging purposes. + compute_using_kernel_sum = True - # Validate output shape - expected_shape = (batch_size, out_channels) - assert ( - out.shape == expected_shape - ), f"Output shape {out.shape} must be {expected_shape}" + if compute_using_kernel_sum: + weights_int32 = weights.to(torch.int32) - return out + input_int32 = input.to(torch.int32) + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) + lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset + output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + else: + weights_int32 = weights.to(torch.int32) + filter_offset -# Fake meta function for shape inference (functional variant) -@register_fake("cortex_m::quantized_linear") -def quantized_linear_meta( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, -) -> torch.Tensor: - # Validate dimensions (same as out variant) - batch_size = input.shape[0] - out_channels = weights.shape[0] + input_int32 = input.to(torch.int32) + input_offset + new_shape = (prod(input.shape[:-1]), input.shape[-1]) + input_reshaped = input_int32.reshape(new_shape) - # Validate weight quantization parameters dimensions - _check_per_tensor_or_per_channel( - weight_zero_point, out_channels, "weight_zero_point" - ) - _check_per_tensor_or_per_channel( - weight_multiplier, out_channels, "weight_multiplier" + output = torch.mm(input_reshaped, weights_int32.T) + if bias is not None: + output = output + bias + output_shape = (*input.shape[:-1], output.shape[-1]) + output_reshaped = output.reshape(output_shape) + + output = requantize_cmsis( + output_reshaped, requantize_multipliers[0], requantize_shifts[0] ) - _check_per_tensor_or_per_channel(weight_shift, out_channels, "weight_shift") + output += output_offset + output = torch.clamp(output, activation_min, activation_max).to(torch.int8) + return output + + +# =================================================================== +# TRANSPOSE OPERATION DEFINITION +# =================================================================== +lib.define("transpose(Tensor input, int[] perm) -> Tensor") +lib.define("transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)") - # Calculate output shape for functional variant - output_shape = (batch_size, out_channels) + +@register_fake("cortex_m::transpose") +def transpose_meta(input: torch.Tensor, perm) -> torch.Tensor: + output_shape = [input.shape[idx] for idx in perm] return torch.empty(output_shape, dtype=input.dtype, device=input.device) -@impl(lib, "quantized_linear.out", "CompositeExplicitAutograd") -def quantized_linear_out_impl( - input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, - *, - out: torch.Tensor, -) -> torch.Tensor: - """ - Fallback implementation for meta/testing - Note: This won't be called at runtime, only during compilation - """ +@impl(lib, "transpose", "CompositeExplicitAutograd") +def transpose_impl(input: torch.Tensor, perm) -> torch.Tensor: + return input.permute(tuple(perm)).contiguous() - # Per-channel dequantization - input_scale = input_multiplier * (2.0 ** (-input_shift)) - input_fp = (input.float() - input_zero_point) * input_scale - if weight_zero_point.numel() == 1: - # Per-tensor - weight_scale = weight_multiplier.item() * (2.0 ** (-weight_shift.item())) - weights_fp = (weights.float() - weight_zero_point.item()) * weight_scale - else: - # Per-channel - weight_scales = weight_multiplier.float() * (2.0 ** (-weight_shift.float())) - weights_fp = ( - weights.float() - weight_zero_point.float().unsqueeze(1) - ) * weight_scales.unsqueeze(1) - bias_fp = None - if bias is not None: - bias_scales = bias_multiplier.float() * (2.0 ** (-bias_shift.float())) - bias_fp = bias.float() * bias_scales - - result_fp = torch.nn.functional.linear(input_fp, weights_fp, bias_fp) - else: - result_fp = torch.nn.functional.linear(input_fp, weights_fp) - result_quantized = torch.clamp( - torch.round(result_fp + output_zero_point), -128, 127 - ).to(torch.int8) - out.copy_(result_quantized) - return out +# =================================================================== +# QUANTIZED CONV2D OPERATION DEFINITION +# =================================================================== -# Functional variant implementation -@impl(lib, "quantized_linear", "CompositeExplicitAutograd") -def quantized_linear_impl( +lib.define( + "quantized_conv2d(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max" + ") -> Tensor" +) + + +lib.define( + "quantized_conv2d.out(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max, " + "*, Tensor(a!) out" + ") -> Tensor(a!)" +) + + +def _compute_conv2d_output_shape( + input_shape: torch.Size, + weight_shape: torch.Size, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], +) -> torch.Size: + batch = input_shape[0] + in_height = input_shape[2] + in_width = input_shape[3] + # We store the weights in OHWI layout (out, kernel_h, kernel_w, in) + kernel_height = weight_shape[1] + kernel_width = weight_shape[2] + + stride_h, stride_w = stride + pad_h, pad_w = padding + dilation_h, dilation_w = dilation + + out_channels = weight_shape[0] + out_height = ( + in_height + 2 * pad_h - dilation_h * (kernel_height - 1) - 1 + ) // stride_h + 1 + out_width = ( + in_width + 2 * pad_w - dilation_w * (kernel_width - 1) - 1 + ) // stride_w + 1 + return torch.Size([batch, out_channels, out_height, out_width]) + + +@register_fake("cortex_m::quantized_conv2d") +def quantized_conv2d_meta( input: torch.Tensor, - input_zero_point: int, - input_multiplier: int, - input_shift: int, - weights: torch.Tensor, - weight_zero_point: torch.Tensor, - weight_multiplier: torch.Tensor, - weight_shift: torch.Tensor, - bias: torch.Tensor, - bias_multiplier: torch.Tensor, - bias_shift: torch.Tensor, - scratch_buffer: torch.Tensor, - output_zero_point: int, - in_features: int, - out_features: int, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, ) -> torch.Tensor: - """ - Functional variant - creates output tensor and calls out variant - """ - # Create output tensor - batch_size = input.shape[0] - output = torch.empty( - (batch_size, out_features), dtype=torch.int8, device=input.device + stride_vals = list(stride) + padding_vals = list(padding) + dilation_vals = list(dilation) + output_shape = _compute_conv2d_output_shape( + input.shape, weight.shape, stride_vals, padding_vals, dilation_vals ) - return quantized_linear_out_impl( - input, - input_zero_point, - input_multiplier, - input_shift, - weights, - weight_zero_point, - weight_multiplier, - weight_shift, - bias, - bias_multiplier, - bias_shift, - scratch_buffer, - output_zero_point, - in_features, - out_features, - out=output, + return torch.empty( + output_shape, + dtype=torch.int8, + device=input.device, + memory_format=torch.channels_last, + ) + + +@impl(lib, "quantized_conv2d", "CompositeExplicitAutograd") +def quantized_conv2d_impl( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, +) -> torch.Tensor: + if input.dim() != 4 or weight.dim() != 4: + raise RuntimeError("quantized_conv2d expects 4D input and weight tensors") + # Convert to int32 for accumulation and apply offsets + input_int32 = input.to(torch.int32) + int(input_offset) + weight_int32 = weight.to(torch.int32) + + if bias is None: + bias_int32 = torch.zeros( + weight.shape[0], dtype=torch.int32, device=input.device + ) + else: + bias_int32 = bias.to(torch.int32) + + input_channels = input.shape[1] + kernel_input_channels = weight.shape[3] + groups = input_channels // kernel_input_channels + + # Convert weights back to OIHW layout expected by torch.nn.functional.conv2d + weight_oi_hw = weight_int32.permute(0, 3, 1, 2).contiguous() + + conv_acc = F.conv2d( + input_int32, + weight_oi_hw, + bias_int32, + stride=tuple(stride), + padding=tuple(padding), + dilation=tuple(dilation), + groups=groups, ) + + result_channels = [] + for output_channel_i in range(conv_acc.shape[1]): + result_channel = requantize_cmsis( + conv_acc[:, output_channel_i, :, :], + int(requantize_multipliers[output_channel_i]), + int(requantize_shifts[output_channel_i]), + ) + result_channels.append(result_channel) + + result = torch.stack(result_channels, dim=1) + + result += output_offset + result = torch.clamp(result, activation_min, activation_max) + + return result.to(torch.int8) diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 81ebeafc778..0b0b2f5c715 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -23,14 +23,38 @@ - arg_meta: null kernel_name: cortex_m::quantized_add_out -- func: cortex_m::quantized_linear(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features) -> Tensor +- func: cortex_m::quantized_mul.out(Tensor self, Scalar self_zero_point, Tensor other, Scalar other_zero_point, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null - kernel_name: cortex_m::quantized_linear + kernel_name: cortex_m::quantized_mul_out -- func: cortex_m::quantized_linear.out(Tensor input, Scalar input_zero_point, Scalar input_multiplier, Scalar input_shift, Tensor weights, Tensor weight_zero_point, Tensor weight_multiplier, Tensor weight_shift, Tensor? bias, Tensor bias_multiplier, Tensor bias_shift, Tensor scratch_buffer, Scalar output_zero_point, Scalar in_features, Scalar out_features, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::minimum_out + +- func: cortex_m::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::maximum_out + +- func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_linear_out + +- func: cortex_m::transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::transpose_out + +- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_conv2d_out diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 26456138cb2..5aeb60be514 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -3,7 +3,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .quantized_linear_fusion_pass import QuantizedLinearFusionPass # noqa +from .activation_fusion_pass import ActivationFusionPass # noqa +from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/activation_fusion_pass.py b/backends/cortex_m/passes/activation_fusion_pass.py new file mode 100644 index 00000000000..b200348cc9d --- /dev/null +++ b/backends/cortex_m/passes/activation_fusion_pass.py @@ -0,0 +1,170 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging + +import executorch.backends.cortex_m.ops.operators # noqa: F401 +from executorch.backends.arm._passes.quant_args import QuantArgs + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger(__name__) + + +class ActivationFusionPass(ExportPass): + """Fuse activations into preceding Cortex-M quantized operators. + + Supported activation patterns: + q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq + + Fusing works by clamping the quantized output range (and zero-point when + required) of the preceding Cortex-M operator, then removing the activation + node from the graph. + """ + + TARGETS = { + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.hardsigmoid.default, + } + + FUSE_OPS = { + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.convolution.default, + } + + def _quantize(self, val, scale, zp, qmin, qmax): + return min(max(round(val / scale + zp), qmin), qmax) + + def _get_validated_qparams(self, node, input_node): + + if "input_qparams" not in input_node.meta or "output_qparams" not in node.meta: + logger.warning( + f"Cannot fuse activation for {input_node.name}->{node.name} as the pattern wasn't quantized properly." + ) + return None + + qparams_dict = node.meta["output_qparams"][0]._asdict() + zp = qparams_dict["zp"] + scale = qparams_dict["scale"] + qmin = qparams_dict["qmin"] + qmax = qparams_dict["qmax"] + + if not isinstance(scale, float) or not isinstance(zp, int): + logger.warning( + f"Cannot fuse activation {node.name} as quantization parameters are not per tensor." + ) + return None + + match node.target: + case exir_ops.edge.aten.relu.default: + quantized_min_val = self._quantize(0, scale, zp, qmin, qmax) + quantized_max_val = qmax + case exir_ops.edge.aten.hardtanh.default: + quantized_min_val = self._quantize(node.args[1], scale, zp, qmin, qmax) + quantized_max_val = self._quantize(node.args[2], scale, zp, qmin, qmax) + case exir_ops.edge.aten.hardsigmoid.default: + quantized_min_val = self._quantize(0, scale, zp, qmin, qmax) + quantized_max_val = self._quantize(1, scale, zp, qmin, qmax) + case _: + raise RuntimeError("Unexpected target {node.target}.") + + # If the minimal quantized value is larger than the qmin, it means that the quantized range contains + # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters. + if qparams_dict["qmin"] != quantized_min_val: + logger.warning( + f"Cannot fuse activation {node.name} as qmin is out of range." + ) + return None + + # If the maximal quantized value is smaller than the qmax, it means that the quantized range contains + # invalid values [quantized_max_val + 1, ... , qmax], indicating bad quantization parameters. + if quantized_max_val != qparams_dict["qmax"]: + logger.warning( + f"Cannot fuse activation {node.name} as qmax is out of range." + ) + return None + + return qparams_dict + + def _update_qparams_hardsigmoid(self, quant_dict): + """ + Returns quant_dict with scale and zp updated to match hardsigmoid activation. + + The quantized output from the hard sigmoid is defined by + Q(y) = clamp(round(y/scale + zp), qmin, qmax) + y = clamp(x/6 + 1/2, 0, 1) + where x is the output of the fused activation op, conv or linear. + + Q(y) can be rewritten as a function of only x: + Q(y) = clamp(round(clamp(x/6 + 1/2, 0, 1)/scale + zp), qmin, qmax) + Q(y) = clamp(round(clamp((x/(6*scale) + 1/(2*scale) + zp, zp, 1/scale + zp)), qmin, qmax) + + From definition of the qparams mapping the output in the range [0,1] to quantized range + [qmin, qmax], we have: + zp = Q(0) <= qmin + 1/scale + zp = Q(1) >= qmax + which makes the inner clamp redundant. + + Therefore, hardsigmoid is equivalent to a quantization with modified parameters + new_scale := 6*scale + new_zp = zp + 1/(2*scale) ~= zp + round(1/(2*scale)) + """ + + new_scale = quant_dict["scale"] * 6 + + new_zp = quant_dict["zp"] + round(1 / (2 * quant_dict["scale"])) + clamped_new_zp = max(quant_dict["qmin"], min(quant_dict["qmax"], new_zp)) + + quant_dict["scale"] = new_scale + quant_dict["zp"] = clamped_new_zp + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + nodes_to_erase: list[Node] = [] + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.TARGETS: + continue + + input_node = node.args[0] + if ( + input_node.op != "call_function" + or input_node.target not in self.FUSE_OPS + ): + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} is not a supported fused activation op." + ) + continue + if len(input_node.users.values()) > 1: + logger.warning( + f"Cannot fuse activation {node.name} as input node {input_node.name} has multiple users." + ) + continue + + if (qparams_dict := self._get_validated_qparams(node, input_node)) is None: + continue + + if node.target == exir_ops.edge.aten.hardsigmoid.default: + self._update_qparams_hardsigmoid(qparams_dict) + + input_node.meta["output_qparams"][0] = QuantArgs(**qparams_dict) + + node.replace_all_uses_with(input_node) + nodes_to_erase.append(node) + modified = True + + for node in nodes_to_erase: + graph_module.graph.erase_node(node) + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py new file mode 100644 index 00000000000..721a1951753 --- /dev/null +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +import torch.fx +from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot + +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, +) + +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.graph_signature import InputKind +from torch.fx.passes.infra.pass_manager import PassResult + + +class ConvertToCortexMPass(XNNPACKPass): + """ + Cortex-M backend pass for replacing supported quantized kernels with Cortex-M + accelerated kernels. + + Used for ops which require changes to input tensors which is not supported + by call_operator. + """ + + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): + """ + Computes the precomputed kernel sum term (bias optional) + a * sum_j(wij + b) + ci + + for i = (1, ..., n), where j indexes the input activations. + """ + weights_transposed = weights.T + weights_int32 = weights_transposed.to(torch.int32) + offset_weights = weights_int32 + weight_offset + kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) + kernel_sum_offset = kernel_sum * input_offset + + if bias is not None: + kernel_sum_offset += bias + + return kernel_sum_offset + + def _get_linear_replacement(self, node): + """ + Let + - yi be the output activations (y1, ... yn) + - xj be the input activations (x1, ... xm) + - wij be the weights (w11, ... wnm) + - a be the input offset + - b be the weight offset + - ci be the bias + + Then the linear operation can be written as: + yi = sum_j((xj + a) * (wij + b)) + ci + = sum_j(xj*wij + xj*b + a*wij + a*b) + ci + = sum_j(xj*wij) + sum_j(xj)*b + (a * sum_j(wij + b) + ci) + = sum_j(xj*wij) + sum_j(xj)*b + kernel_sum + + where kernel_sum is precomputed aot. + """ + input_scale = node.meta["input_qparams"][0].scale + input_zp = node.meta["input_qparams"][0].zp + weight_scale = node.meta["input_qparams"][1].scale + weight_zp = node.meta["input_qparams"][1].zp + output_scale = node.meta["output_qparams"][0].scale + output_zp = node.meta["output_qparams"][0].zp + output_min = node.meta["output_qparams"][0].qmin + output_max = node.meta["output_qparams"][0].qmax + + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + (input_scale * weight_scale) / output_scale + ) + + # TODO: Add support for configuring the backend to support other extensions. + # Kernel sum is only used in the CMSIS-NN implementation for the MVE extension, + # so this should be optional. + weights = node.args[1] + weights_tensor = get_param_tensor(self.exported_program, weights) + bias_tensor = ( + get_param_tensor(self.exported_program, node.args[2]) + if len(node.args) > 2 + else None + ) + kernel_sum_tensor = self._compute_kernel_sum( + weights_tensor, bias_tensor, -input_zp, -weight_zp + ) + with node.graph.inserting_after(weights): + kernel_sum = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_kernel_sum", + InputKind.PARAMETER, + kernel_sum_tensor, + ) + + args = ( + node.args[0], + weights, + None, + kernel_sum, + -input_zp, + -weight_zp, + output_zp, + [quantized_multiplier], + [quantized_shift], + output_max, + output_min, + ) + + return exir_ops.edge.cortex_m.quantized_linear.default, args + + def _get_convolution_replacement(self, node) -> int: + ( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) = node.args + + # Extract values + input_scale = node.meta["input_qparams"][0].scale + input_zero_point = node.meta["input_qparams"][0].zp + weight_scales = node.meta["input_qparams"][1].scale + if not isinstance(weight_scales, list): + weight_scales = [weight_scales] * weight.data.shape[0] + + output_qparams = node.meta["output_qparams"][0] + output_scale = output_qparams.scale + output_zero_point = output_qparams.zp + output_qmin = output_qparams.qmin + output_qmax = output_qparams.qmax + + quantized_multipliers = [] + quantized_shifts = [] + for weight_scale in weight_scales: + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + input_scale * weight_scale / output_scale + ) + quantized_multipliers.append(quantized_multiplier) + quantized_shifts.append(quantized_shift) + + # Permute the weight tensor to the OHWI layout expected by CMSIS-NN. + weight_tensor = get_param_tensor(self.exported_program, weight) + weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous( + memory_format=torch.channels_last + ) + + with node.graph.inserting_after(weight): + weight_nhwc = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_weight_nhwc", + InputKind.PARAMETER, + weight_permuted, + ) + + new_args = ( + x, + weight_nhwc, + bias, + stride, + padding, + dilation, + -input_zero_point, + output_zero_point, + torch.tensor(quantized_multipliers, dtype=torch.int32), + torch.tensor(quantized_shifts, dtype=torch.int32), + output_qmin, + output_qmax, + ) + return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if ( + node.meta.get("input_qparams", {}) == {} + or node.meta.get("output_qparams", {}) == {} + ): + continue + + match node.target: + case exir_ops.edge.aten.linear.default: + op, args = self._get_linear_replacement(node) + case exir_ops.edge.aten.convolution.default: + op, args = self._get_convolution_replacement(node) + case _: + continue + + with graph_module.graph.inserting_before(node): + cortex_m_op = graph_module.graph.create_node( + "call_function", + target=op, + args=args, + kwargs={}, + ) + + node.replace_all_uses_with(cortex_m_op) + graph_module.graph.erase_node(node) + + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 02429cc68e0..fd89986cef0 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -4,26 +4,36 @@ # LICENSE file in the root directory of this source tree. -from executorch.backends.arm._passes import ScalarsToAttributePass +import inspect + +from executorch.backends.arm._passes import ( + FoldAndAnnotateQParamsPass, + ScalarsToAttributePass, +) from executorch.backends.cortex_m.passes import ( - QuantizedLinearFusionPass, + ActivationFusionPass, + ConvertToCortexMPass, QuantizedOpFusionPass, ReplaceQuantNodesPass, ) from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) -from executorch.backends.xnnpack._passes import XNNPACKPassManager from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_manager import PassManager +from executorch.exir.program._program import _transform +from torch.export import ExportedProgram -class CortexMPassManager(XNNPACKPassManager): +class CortexMPassManager(PassManager): pass_list: list[ExportPass] = [ + FoldAndAnnotateQParamsPass, ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, QuantizedOpFusionPass, - QuantizedLinearFusionPass, + ActivationFusionPass, + ConvertToCortexMPass, ] pass_list_transform_for_annotation: list[ExportPass] = [ @@ -32,10 +42,29 @@ class CortexMPassManager(XNNPACKPassManager): ] def __init__(self, exported_program, passes=None): - super().__init__(exported_program, passes or self.pass_list) + self.exported_program = exported_program + if passes is not None: + self.passes = passes + else: + self.passes = self.pass_list def transform_for_annotation(self, model): passes = self.pass_list_transform_for_annotation for p in passes: model = p().call(model).graph_module return model + + def transform(self) -> ExportedProgram: + ep = self.exported_program + for pass_ in self.passes: + signature = inspect.signature(pass_.__init__) + if "exported_program" in signature.parameters: + transform_pass = pass_(ep) + elif issubclass(pass_, ExportPass): + transform_pass = pass_() + else: + raise RuntimeError( + f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}" + ) + ep = _transform(ep, transform_pass) + return ep diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py index b045005d34d..de07db2443a 100644 --- a/backends/cortex_m/passes/passes_utils.py +++ b/backends/cortex_m/passes/passes_utils.py @@ -50,14 +50,32 @@ def requantize_cmsis( multiplier: int, shift: int, ) -> torch.Tensor: - """ - Simulate CMSIS-NN fixed-point requantization: - result = round(tensor * multiplier / (2 ^ shift)) - with double rounding - """ - multiplied = torch.round(tensor.to(torch.int64) * multiplier) - shifted = torch.round(multiplied / (2 ** (31 - shift))) - return shifted.to(torch.int32) + """Simulate CMSIS-NN's arm_nn_requantize helper.""" + + tensor_64 = tensor.to(torch.int64) + left_shift = max(shift, 0) + right_shift = max(-shift, 0) + + # Equivalent to val * (1 << LEFT_SHIFT(shift)) + value = tensor_64 << left_shift + + # arm_nn_doubling_high_mult_no_sat(value, multiplier) + product = value * int(multiplier) + product = product + (1 << 30) + result = product >> 31 + + if right_shift: + remainder_mask = (1 << right_shift) - 1 + remainder = torch.bitwise_and(result, remainder_mask) + result = result >> right_shift + threshold = remainder_mask >> 1 + threshold_tensor = torch.full_like(result, threshold, dtype=torch.int64) + threshold_tensor = torch.where( + result < 0, threshold_tensor + 1, threshold_tensor + ) + result = result + torch.where(remainder > threshold_tensor, 1, 0) + + return result.to(torch.int32) def extract_scalar_value(node_arg) -> float: diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py deleted file mode 100644 index 11a49beb2f4..00000000000 --- a/backends/cortex_m/passes/quantized_linear_fusion_pass.py +++ /dev/null @@ -1,646 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from typing import Optional - -import executorch.backends.cortex_m.ops.operators # noqa -import torch -import torch.fx - -from executorch.backends.cortex_m.passes.passes_utils import ( - cleanup_nodes, - is_dequant_node, - quantize_multiplier_aot, - transfer_metadata, -) - -from executorch.backends.transforms.utils import create_mutable_buffer, get_param_tensor - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.exir import ExportedProgram -from executorch.exir.dialects._ops import ops as exir_ops -from torch.fx import Node -from torch.fx.passes.infra.pass_manager import PassResult - -logger = logging.getLogger("quantized_linear_fusion_pass") -logger.setLevel(logging.INFO) - - -class QuantizedLinearFusionPass(XNNPACKPass): - """ - Cortex-M backend pass that fuses quantized linear-like patterns. - Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize - Into: cortex_m.quantized_linear.default with direct parameters. - """ - - SUPPORTED_OPS_MAPPING = { - exir_ops.edge.aten.addmm.default: exir_ops.edge.cortex_m.quantized_linear.default, - exir_ops.edge.aten.mm.default: exir_ops.edge.cortex_m.quantized_linear.default, - } - - requires_exported_program = True - - def __init__(self, exported_program: ExportedProgram): - super().__init__(exported_program) - self.nodes_to_erase = [] - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - logger.info("Starting QuantizedLinearFusionPass") - assert id(self._exported_program.graph_module.graph) == id( - graph_module.graph - ), "QuantizedLinearFusionPass requires same graph instance" - - try: - fusion_count = self._fuse_quantized_linear_patterns(graph_module) - if fusion_count > 0: - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() - graph_module.recompile() - logger.info(f"Linear fusion completed: {fusion_count} patterns fused") - return PassResult(graph_module, fusion_count > 0) - except Exception as e: - logger.error(f"Error in QuantizedLinearFusionPass: {e}") - raise e - - def _extract_linear_pattern(self, quantize_node: Node): - if not quantize_node.args: - return None - fc_node = quantize_node.args[0] - if not ( - fc_node.op == "call_function" - and fc_node.target in self.SUPPORTED_OPS_MAPPING - ): - return None - - op_name = str(fc_node.target).split(".")[-1] - - if "addmm" in str(fc_node.target): - input_dq_node = fc_node.args[1] - else: - input_dq_node = fc_node.args[0] - if not is_dequant_node(input_dq_node): - logger.info("input_dq_node is not a dequant node") - return None - weight_dq_node, bias_dq_node = self._extract_weight_bias_from_fc_op(fc_node) - if not weight_dq_node: - logger.info("No weight, bias dequantize node found") - return None - return ( - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - op_name, - ) - - def _extract_weight_bias_from_fc_op(self, fc_node: Node): - """Generic extraction for FC-like operations.""" - - if "addmm" in str(fc_node.target): - if len(fc_node.args) >= 3: - bias_arg = fc_node.args[0] - weight_arg = fc_node.args[2] - weight_dq_node = self._trace_to_dequantize(weight_arg) - logger.info( - f"weight_arg: {weight_arg}, traced weight_dq_node: {weight_dq_node}" - ) - - if weight_dq_node is None: - logger.info("No weight dequantize node found ") - - # For bias, try to trace to dequantize but allow None (no-bias case) - bias_dq_node = self._trace_to_dequantize(bias_arg) - if bias_dq_node is None: - logger.info("No bias dequantize node found - likely no-bias linear") - return weight_dq_node, bias_dq_node - elif any(op in str(fc_node.target) for op in ["linear", "mm"]): - if len(fc_node.args) >= 2: - weight_arg = fc_node.args[1] - bias_arg = fc_node.args[2] if len(fc_node.args) > 2 else None - weight_dq_node = self._trace_to_dequantize(weight_arg) - bias_dq_node = self._trace_to_dequantize(bias_arg) if bias_arg else None - return weight_dq_node, bias_dq_node - return None, None - - def _extract_input_quantization_parameters( - self, input_dq_node: Node - ) -> Optional[dict]: - """Extract input quantization parameters from dequantize node.""" - try: - # Find the quantize operation that produces the int8 tensor - input_quantize_node = None - if hasattr(input_dq_node, "args") and input_dq_node.args: - quantize_candidate = input_dq_node.args[0] - if getattr( - quantize_candidate, "op", None - ) == "call_function" and "quantize" in str( - getattr(quantize_candidate, "target", "") - ): - input_quantize_node = quantize_candidate - - if not input_quantize_node: - logger.error("Could not find quantize node for input!") - return None - - # Extract input quantization parameters - input_scale = self._extract_param_value(input_dq_node.args[1]) - input_zero_point = int(self._extract_param_value(input_dq_node.args[2])) - input_multiplier, input_shift = quantize_multiplier_aot(input_scale) - - return { - "input_scale": input_scale, - "input_zero_point": input_zero_point, - "input_multiplier": input_multiplier, - "input_shift": input_shift, - "input_tensor": input_quantize_node, - } - except Exception as e: - logger.error(f"Failed to extract input quantization parameters: {e}") - return None - - def _extract_output_quantization_parameters( - self, quantize_node: Node - ) -> Optional[dict]: - """Extract output quantization parameters from quantize node.""" - try: - output_scale = self._extract_param_value(quantize_node.args[1]) - output_zero_point = int(self._extract_param_value(quantize_node.args[2])) - - return { - "output_scale": output_scale, - "output_zero_point": output_zero_point, - } - except Exception as e: - logger.error(f"Failed to extract output quantization parameters: {e}") - return None - - def _create_constant_parameter_buffer( - self, graph, quantize_node: Node, data: torch.Tensor, name: str - ): - """Create a parameter buffer""" - buffer_name = f"{name}_{id(quantize_node)}" - - setattr(graph.owning_module, buffer_name, data) - - # Create a get_attr node - with graph.inserting_before(quantize_node): - buffer_node = graph.create_node( - op="get_attr", target=buffer_name, name=buffer_name - ) - - # Set metadata - buffer_node.meta["val"] = data - - return buffer_node - - def _extract_weight_parameters(self, weight_dq_node: Node) -> Optional[dict]: - try: - weight_tensor = weight_dq_node.args[0] - weight_scale = weight_dq_node.args[1] - weight_zero_point = ( - weight_dq_node.args[2] if len(weight_dq_node.args) > 2 else None - ) - - weight_scale_data = self._extract_param_value(weight_scale) - weight_zp_data = ( - self._extract_param_value(weight_zero_point) - if weight_zero_point - else None - ) - - # Get actual tensor data to determine output features - weight_tensor_data = get_param_tensor(self._exported_program, weight_tensor) - out_features = weight_tensor_data.shape[0] - - # Handle both per-tensor and per-channel - if ( - isinstance(weight_scale_data, torch.Tensor) - and weight_scale_data.numel() > 1 - ): - # Per-channel: ensure we have the right number of elements - assert ( - weight_scale_data.numel() == out_features - ), f"Scale size {weight_scale_data.numel()} != out_features {out_features}" - - multipliers = [] - shifts = [] - for scale in weight_scale_data: - mult, shift = quantize_multiplier_aot(scale.item()) - multipliers.append(mult) - shifts.append(shift) - - weight_multiplier = torch.tensor(multipliers, dtype=torch.int32) - weight_shift = torch.tensor(shifts, dtype=torch.int32) - weight_zp_tensor = ( - weight_zp_data.int() - if weight_zp_data is not None - else torch.zeros(out_features, dtype=torch.int32) - ) - else: - # Per-tensor: create tensors with correct size for output features - scale_val = ( - weight_scale_data.item() - if isinstance(weight_scale_data, torch.Tensor) - else weight_scale_data - ) - mult, shift = quantize_multiplier_aot(scale_val) - - # Create tensors sized for out_features (not single element) - weight_multiplier = torch.full((out_features,), mult, dtype=torch.int32) - weight_shift = torch.full((out_features,), shift, dtype=torch.int32) - weight_zp_tensor = torch.full( - (out_features,), - weight_zp_data if weight_zp_data else 0, - dtype=torch.int32, - ) - - # Validate multipliers - for i, mult in enumerate(weight_multiplier): - if mult < (1 << 30) or mult > ((1 << 31) - 1): - logger.error( - f"Invalid multiplier[{i}]: {mult}, scale was: {weight_scale_data}" - ) - return None - - return { - "weight_tensor": weight_tensor, - "weight_zero_point_data": weight_zp_tensor, - "weight_multiplier_data": weight_multiplier, - "weight_shift_data": weight_shift, - } - except Exception as e: - logger.error(f"Failed to extract weight parameters: {e}") - return None - - def _extract_bias_parameters(self, bias_dq_node: Optional[Node]) -> Optional[dict]: - """ - Extract bias parameters for quantized linear fusion. - Handles both dequantized bias nodes and constant bias tensors. - Returns a dict with bias_tensor, bias_multiplier, and bias_shift. - """ - if not bias_dq_node: - # No bias present - return None - try: - # Case 1: Bias is a dequantize node - if hasattr(bias_dq_node, "op") and is_dequant_node(bias_dq_node): - bias_tensor = bias_dq_node.args[0] - bias_scale = bias_dq_node.args[1] - - bias_scale_data = self._extract_param_value(bias_scale) - - if ( - isinstance(bias_scale_data, torch.Tensor) - and bias_scale_data.numel() > 1 - ): - # Per-channel bias - bias_multipliers = [] - bias_shifts = [] - for scale_val in bias_scale_data.tolist(): - mult, shift = quantize_multiplier_aot(scale_val) - bias_multipliers.append(mult) - bias_shifts.append(shift) - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multipliers, - "bias_shift": bias_shifts, - } - else: - # Per-tensor bias - bias_scale_val = ( - bias_scale_data.item() - if isinstance(bias_scale_data, torch.Tensor) - else bias_scale_data - ) - bias_multiplier, bias_shift = quantize_multiplier_aot( - bias_scale_val - ) - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multiplier, - "bias_shift": bias_shift, - } - else: - # Case 2: Bias is a constant tensor (not dequantized) - # This can happen if bias is not quantized in the model - bias_tensor = bias_dq_node - # Use default multiplier/shift for unquantized bias - bias_multiplier = 1 - bias_shift = 0 - return { - "bias_tensor": bias_tensor, - "bias_multiplier": bias_multiplier, - "bias_shift": bias_shift, - } - except Exception as e: - logger.error(f"Failed to extract bias parameters: {e}") - return None - - def _prepare_bias_tensors( - self, bias_params: Optional[dict], out_features: int - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Prepare bias multiplier and shift tensors for kernel call. - Returns (bias_multiplier_tensor, bias_shift_tensor) both sized [out_features]. - """ - if bias_params: - bias_multiplier = bias_params["bias_multiplier"] - bias_shift = bias_params["bias_shift"] - - # Convert to tensors of the right size - if isinstance(bias_multiplier, int): - bias_multiplier_tensor = torch.full( - [out_features], bias_multiplier, dtype=torch.int32 - ) - elif isinstance(bias_multiplier, list): - assert ( - len(bias_multiplier) == out_features - ), f"Bias multiplier size {len(bias_multiplier)} != out_features {out_features}" - bias_multiplier_tensor = torch.tensor( - bias_multiplier, dtype=torch.int32 - ) - elif isinstance(bias_multiplier, torch.Tensor): - assert ( - bias_multiplier.numel() == out_features - ), f"Bias multiplier size {bias_multiplier.numel()} != out_features {out_features}" - bias_multiplier_tensor = bias_multiplier - else: - raise TypeError( - f"Unsupported bias_multiplier type: {type(bias_multiplier)}" - ) - - if isinstance(bias_shift, int): - bias_shift_tensor = torch.full( - [out_features], bias_shift, dtype=torch.int32 - ) - elif isinstance(bias_shift, list): - assert ( - len(bias_shift) == out_features - ), f"Bias shift size {len(bias_shift)} != out_features {out_features}" - bias_shift_tensor = torch.tensor(bias_shift, dtype=torch.int32) - elif isinstance(bias_shift, torch.Tensor): - assert ( - bias_shift.numel() == out_features - ), f"Bias shift size {bias_shift.numel()} != out_features {out_features}" - bias_shift_tensor = bias_shift - else: - raise TypeError(f"Unsupported bias_shift type: {type(bias_shift)}") - - return bias_multiplier_tensor, bias_shift_tensor - else: - # No bias: return zero tensors of correct shape - return ( - torch.zeros([out_features], dtype=torch.int32), - torch.zeros([out_features], dtype=torch.int32), - ) - - def _extract_param_value(self, node_or_value): - """ - Extract a scalar value from a Node or a direct float/int. - """ - if isinstance(node_or_value, (float, int)): - return node_or_value - # If it's a tensor, get its scalar value if possible - if isinstance(node_or_value, torch.Tensor): - return node_or_value.item() if node_or_value.numel() == 1 else node_or_value - # If it's a Node, use get_param_tensor - if hasattr(node_or_value, "op"): - tensor = get_param_tensor(self._exported_program, node_or_value) - return tensor.item() if tensor.numel() == 1 else tensor - raise TypeError(f"Unsupported parameter type: {type(node_or_value)}") - - def _calculate_cmsis_scratch_size(self, weight_tensor) -> int: - """Calculate CMSIS-NN scratch buffer size for quantized linear operations. - - Source: CMSIS-NN arm_fully_connected_s8_get_buffer_size() returns filter_dims->w * sizeof(int32_t). - This buffer stores pre-computed kernel sums (weight row sums) - one int32_t per output feature. - Same buffer size applies to both per-tensor and per-channel quantization paths since both use - identical kernel sum optimization in the underlying matrix multiplication. - """ - try: - print(f"weight_tensor type: {type(weight_tensor)}, value: {weight_tensor}") - weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape - out_features = weight_shape[0] # filter_dims->w in CMSIS terms - - # CMSIS-NN implementation expects the following size - cmsis_buffer_size = out_features * 4 # sizeof(int32_t) - return cmsis_buffer_size - except Exception as e: - logger.error(f"Failed to calculate CMSIS scratch size: {e}") - return 2048 # Fallback - - def _create_scratch_buffer(self, graph, quantize_node: Node, weight_tensor): - cmsis_scratch = self._calculate_cmsis_scratch_size(weight_tensor) - - kernel_sum_header = 8 # sizeof(KernelSumHeader) - total_size = kernel_sum_header + cmsis_scratch - - logger.info( - f"Kernel sum header: {kernel_sum_header}, CMSIS buffer: {cmsis_scratch}, total: {total_size}" - ) - - return create_mutable_buffer( - self._exported_program, - name=f"b_cmsis_linear_scratch_{id(quantize_node)}", - data=torch.zeros((total_size,), dtype=torch.int8), - ) - - def _create_fused_node( - self, - graph, - quantize_node: Node, - quant_params: dict, - weight_params: dict, - bias_params: Optional[dict], - quantized_target, - ) -> Node: - """Generic fused node creation for any FC-like operation.""" - # Extract all parameters - input_tensor = quant_params["input_tensor"] - input_zp = quant_params["input_zero_point"] - input_multiplier = quant_params["input_multiplier"] - input_shift = quant_params["input_shift"] - weight_tensor = weight_params["weight_tensor"] - - weight_zp_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_zero_point_data"], "weight_zp" - ) - weight_mult_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_multiplier_data"], "weight_mult" - ) - weight_shift_node = self._create_constant_parameter_buffer( - graph, quantize_node, weight_params["weight_shift_data"], "weight_shift" - ) - # Get dimensions - weight_shape = get_param_tensor(self._exported_program, weight_tensor).shape - assert ( - len(weight_shape) == 2 - ), f"Weight tensor must be 2D, got shape {weight_shape}" - in_features = weight_shape[1] - out_features = weight_shape[0] - - # Handle bias - bias_tensor = bias_params["bias_tensor"] if bias_params else None - bias_multiplier, bias_shift = self._prepare_bias_tensors( - bias_params, out_features - ) - output_zp = quant_params["output_zero_point"] - - scratch_buffer = self._create_scratch_buffer( - graph, quantize_node, weight_tensor - ) - - with graph.inserting_after(quantize_node): - fused = graph.create_node( - "call_function", - target=quantized_target, - args=( - input_tensor, - input_zp, - input_multiplier, - input_shift, - weight_tensor, - weight_zp_node, - weight_mult_node, - weight_shift_node, - bias_tensor, - bias_multiplier, - bias_shift, - scratch_buffer, - output_zp, - in_features, - out_features, - ), - kwargs={}, - ) - - transfer_metadata(fused, quantize_node, "QuantizedLinearFusionPass") - return fused - - def _mark_for_cleanup(self, nodes): - for node in nodes: - if node is not None: - self.nodes_to_erase.append(node) - - def _cleanup_nodes(self, graph): - cleanup_nodes(self.nodes_to_erase, graph) - self.nodes_to_erase.clear() - - def _extract_linear_pattern_with_validation(self, quantize_node: Node): - pattern_info = self._extract_linear_pattern(quantize_node) - if not pattern_info: - return None - # Optionally add more validation here if needed - return pattern_info - - def _trace_to_dequantize(self, node: Optional[Node], max_depth=3) -> Optional[Node]: - """Trace through transformations to find dequantize node.""" - current_node = node - depth = 0 - while current_node and depth < max_depth: - if is_dequant_node(current_node): - return current_node - if current_node.op == "call_function" and current_node.target in { - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.view_copy.default, - }: - if current_node.args: - current_node = current_node.args[0] - depth += 1 - continue - break - return None - - def _fuse_quantized_linear_patterns( - self, graph_module: torch.fx.GraphModule - ) -> int: - fusion_count = 0 - graph = graph_module.graph - for node in list(graph.nodes): - if not ( - node.op == "call_function" and "quantize_per_tensor" in str(node.target) - ): - continue - pattern_info = self._extract_linear_pattern_with_validation(node) - if not pattern_info: - continue - - ( - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - op_name, - ) = pattern_info - - # Get quantized target for this FC operation - quantized_target = self.SUPPORTED_OPS_MAPPING.get(fc_node.target) - if not quantized_target: - logger.warning(f"No quantized target found for {fc_node.target}") - continue - - logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") - - try: - input_params = self._extract_input_quantization_parameters( - input_dq_node - ) - if not input_params: - logger.error( - "Quantization parameter extraction failed for node: %s", node - ) - return None - output_params = self._extract_output_quantization_parameters( - quantize_node - ) - if not output_params: - logger.error( - "Output quantization parameter extraction failed for node: %s", - node, - ) - return None - quant_params = {**input_params, **output_params} - logger.info(f"Quantization parameters: {quant_params}") - - weight_params = self._extract_weight_parameters(weight_dq_node) - if not weight_params: - continue - bias_params = self._extract_bias_parameters(bias_dq_node) - if bias_dq_node and not bias_params: - continue - fused_node = self._create_fused_node( - graph, - quantize_node, - quant_params, - weight_params, - bias_params, - quantized_target, - ) - logger.info(f"Created fused {op_name} node: {fused_node}") - - quantize_node.replace_all_uses_with(fused_node) - self._mark_for_cleanup( - [ - quantize_node, - fc_node, - input_dq_node, - weight_dq_node, - bias_dq_node, - ] - ) - fusion_count += 1 - logger.info(f"✅ Successfully fused {op_name} operation {fusion_count}") - except Exception as e: - logger.error( - f"Failed to fuse {op_name} pattern for {fc_node.name}: {e}" - ) - continue - self._cleanup_nodes(graph) - return fusion_count diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py index 888155dcfd0..c84e66dd7d9 100644 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ b/backends/cortex_m/passes/quantized_op_fusion_pass.py @@ -5,23 +5,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging -from typing import Set +from typing import Dict -import executorch.backends.cortex_m.ops.operators # noqa import torch from executorch.backends.cortex_m.passes.passes_utils import ( - extract_scalar_value, quantize_multiplier_aot, SHIFT_INT8, ) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass -from torch.fx.passes.infra.pass_manager import PassResult -logger = logging.getLogger("quant_op_fusion_pass") -logger.setLevel(logging.INFO) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from torch.fx.node import Argument class QuantizedOpFusionPass(ExportPass): @@ -35,234 +31,117 @@ class QuantizedOpFusionPass(ExportPass): Supports multiple binary operations with backward compatibility for add. """ - # Generic operation mapping - SUPPORTED_OPS_MAPPING = { - exir_ops.edge.aten.add.Tensor: exir_ops.edge.cortex_m.quantized_add.default, - # Future binary ops to be added here: - } - - def __init__(self): - super().__init__() - - def _get_dequant_targets(self) -> Set: - """Support both decomposed and cortex_m dequant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.cortex_m.dequantize_per_tensor.default, - } - - def _get_quant_targets(self) -> Set: - """Support both decomposed and cortex_m quant targets for flexible pass ordering.""" - return { - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - def _is_supported_binary_op(self, node: torch.fx.Node) -> bool: - """Check if node is a supported binary operation.""" - is_supported = ( - node.op == "call_function" and node.target in self.SUPPORTED_OPS_MAPPING + def _get_add_replacement(self, args, meta): + if ( + meta.data.get("input_qparams", {}) == {} + or meta.data.get("output_qparams", {}) == {} + ): + return exir_ops.edge.aten.add.Tensor, args + + # Extract values + scale1 = meta["input_qparams"][0].scale + zero_point1 = meta["input_qparams"][0].zp + scale2 = meta["input_qparams"][1].scale + zero_point2 = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zero_point = meta["output_qparams"][0].zp + + # AoT COMPUTATION: Calculate multipliers and shifts + max_scale_2x = 2 * max(scale1, scale2) + + input1_mult, input1_shift = quantize_multiplier_aot(scale1 / max_scale_2x) + input2_mult, input2_shift = quantize_multiplier_aot(scale2 / max_scale_2x) + output_mult, output_shift = quantize_multiplier_aot( + max_scale_2x / (output_scale * (1 << SHIFT_INT8)) ) - if not is_supported: - return False - - shape1 = node.args[0].meta["val"].shape - shape2 = node.args[1].meta["val"].shape - is_broadcast = shape1 != shape2 - return not is_broadcast - def _is_dequant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a dequantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_dequant_targets() + args = ( + args[0], + zero_point1, + input1_mult, + input1_shift, + args[1], + zero_point2, + input2_mult, + input2_shift, + output_zero_point, + output_mult, + output_shift, ) - def _is_quant_node(self, node: torch.fx.Node) -> bool: - """Check if node is a quantize operation.""" - return ( - hasattr(node, "op") - and node.op == "call_function" - and node.target in self._get_quant_targets() + return exir_ops.edge.cortex_m.quantized_add.default, args + + def _get_mul_replacement(self, args, meta): + if ( + meta.data.get("input_qparams", {}) == {} + or meta.data.get("output_qparams", {}) == {} + ): + return exir_ops.edge.aten.mul.Tensor, args + + # Extract values + scale1 = meta["input_qparams"][0].scale + zero_point1 = meta["input_qparams"][0].zp + scale2 = meta["input_qparams"][1].scale + zero_point2 = meta["input_qparams"][1].zp + output_scale = meta["output_qparams"][0].scale + output_zero_point = meta["output_qparams"][0].zp + + scale_factor = (scale1 * scale2) / output_scale + output_mult, output_shift = quantize_multiplier_aot(scale_factor) + + args = ( + args[0], + zero_point1, + args[1], + zero_point2, + output_zero_point, + output_mult, + output_shift, ) - def _transfer_metadata( - self, - new_node: torch.fx.Node, - source_node: torch.fx.Node, - pass_name: str = "QuantizedOpFusionPass", - ) -> None: - """Metadata transfer with proper provenance tracking.""" - if hasattr(source_node, "meta") and source_node.meta: - new_node.meta = source_node.meta.copy() - - if "from_node" in new_node.meta: - from_node_list = new_node.meta.get("from_node", []).copy() - from_node_list.append( - {"source": source_node.name, "pass": pass_name, "op": "fuse"} - ) - new_node.meta["from_node"] = from_node_list - - # Copy essential fields - for field in ["tensor_meta", "stack_trace"]: - if field in source_node.meta: - new_node.meta[field] = source_node.meta[field] - - def _normalize_to_cortex_m_targets(self, graph_module: torch.fx.GraphModule) -> int: - """Convert decomposed targets to cortex_m equivalents for consistent handling.""" - target_mapping = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.cortex_m.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.cortex_m.quantize_per_tensor.default, - } - - normalization_count = 0 - for node in list(graph_module.graph.nodes): - if node.op == "call_function" and node.target in target_mapping: - logger.info(f"Normalizing {node.target} to cortex_m equivalent") - node.target = target_mapping[node.target] - normalization_count += 1 - - return normalization_count - - def _fuse_quantized_binary_patterns( - self, graph_module: torch.fx.GraphModule - ) -> int: - """Generic fusion for quantized binary operation patterns.""" - fusion_count = 0 - nodes_to_erase = [] - - for node in list(graph_module.graph.nodes): - if not self._is_quant_node(node): - continue - - quantize_node = node - if not quantize_node.args: - continue - - binary_op_node = quantize_node.args[0] - if not self._is_supported_binary_op(binary_op_node): - continue + return exir_ops.edge.cortex_m.quantized_mul.default, args - if len(binary_op_node.args) < 2: - continue + def _get_minimum_replacement(self, args, meta): + if args[0].data.dtype != torch.int8: + return exir_ops.edge.aten.minimum.default, args - dequant_node1, dequant_node2 = binary_op_node.args[:2] - if not ( - self._is_dequant_node(dequant_node1) - and self._is_dequant_node(dequant_node2) - ): - continue + return exir_ops.edge.cortex_m.minimum.default, args - # Get the target quantized operation - quantized_target = self.SUPPORTED_OPS_MAPPING[binary_op_node.target] - # Extract op name (e.g., 'Tensor' -> 'add') - op_name = str(binary_op_node.target).split(".")[-1] - logger.info(f"✅ Found complete cortex_m Q/DQ + {op_name} pattern!") + def _get_maximum_replacement(self, args, meta): + if args[0].data.dtype != torch.int8: + return exir_ops.edge.aten.maximum.default, args - try: - # Extract values - int8_tensor1, scale1, zero_point1 = dequant_node1.args[:3] - int8_tensor2, scale2, zero_point2 = dequant_node2.args[:3] - output_scale, output_zero_point = quantize_node.args[1:3] + return exir_ops.edge.cortex_m.maximum.default, args - # Convert to Python floats - scale1_val = extract_scalar_value(scale1) - scale2_val = extract_scalar_value(scale2) - output_scale_val = extract_scalar_value(output_scale) - zp1_val = int(extract_scalar_value(zero_point1)) - zp2_val = int(extract_scalar_value(zero_point2)) - output_zp_val = int(extract_scalar_value(output_zero_point)) + def _get_permute_replacement(self, args, meta): + if args[0].data.dtype != torch.int8: + return exir_ops.edge.aten.permute_copy.default, args - max_scale_2x = 2 * max(scale1_val, scale2_val) - # AoT COMPUTATION: Calculate multipliers and shifts + rank = len(args[0].data.shape) + perms = [p % rank for p in args[1]] + args = (args[0], perms) + return exir_ops.edge.cortex_m.transpose.default, args - input1_mult, input1_shift = quantize_multiplier_aot( - scale1_val / max_scale_2x - ) - input2_mult, input2_shift = quantize_multiplier_aot( - scale2_val / max_scale_2x - ) - output_mult, output_shift = quantize_multiplier_aot( - max_scale_2x / (output_scale_val * (1 << SHIFT_INT8)) - ) - - logger.info("AoT computed parameters:") - logger.info(f" Input1: mult={input1_mult}, shift={input1_shift}") - logger.info(f" Input2: mult={input2_mult}, shift={input2_shift}") - logger.info(f" Output: mult={output_mult}, shift={output_shift}") - - with graph_module.graph.inserting_after(quantize_node): - fused = graph_module.graph.create_node( - "call_function", - target=quantized_target, - args=( - int8_tensor1, - zp1_val, - input1_mult, - input1_shift, - int8_tensor2, - zp2_val, - input2_mult, - input2_shift, - output_zp_val, - output_mult, - output_shift, - ), - kwargs={}, - ) - - # metadata transfer - self._transfer_metadata(fused, quantize_node) - - logger.info(f"✅ Created fused quantized_{op_name} node: {fused}") - - # Replace all uses - quantize_node.replace_all_uses_with(fused) - binary_op_node.replace_all_uses_with(fused) - dequant_node1.replace_all_uses_with(fused) - dequant_node2.replace_all_uses_with(fused) - - nodes_to_erase.extend( - [quantize_node, binary_op_node, dequant_node1, dequant_node2] - ) - fusion_count += 1 - logger.info(f"Pattern fused, total so far: {fusion_count}") - - except Exception as e: - logger.info(f"❌ Error during AoT computation: {e}") - logger.info(" Skipping fusion for this pattern") - continue - - for old_node in reversed(nodes_to_erase): - if old_node in graph_module.graph.nodes and len(old_node.users) == 0: - logger.info(f"🗑️ Erasing node: {old_node}") - graph_module.graph.erase_node(old_node) - - return fusion_count - - def call(self, graph_module: torch.fx.GraphModule): - logger.info("QuantizedOpFusionPass.call() started") - - # Normalize targets for flexible pass ordering - normalization_count = self._normalize_to_cortex_m_targets(graph_module) - - # Generic fusion for supported binary operations - fusion_count = self._fuse_quantized_binary_patterns(graph_module) - - total_changes = normalization_count + fusion_count - logger.info(f"Total changes: {total_changes}") - - if total_changes > 0: - graph_module.graph.eliminate_dead_code() - graph_module.graph.lint() - graph_module.recompile() - - logger.debug("=== AFTER FUSION: All nodes in the graph ===") - for i, node in enumerate(graph_module.graph.nodes): - logger.debug(f"Node {i}: op={node.op}, target={node.target}") - if "quantized_" in str(node.target) and "add" in str(node.target): - logger.debug(" ⭐ FOUND QUANTIZED BINARY OP NODE! ⭐") - logger.debug("=== END DEBUG ===") - - return PassResult(graph_module, total_changes > 0) + def call_operator( + self, + op: EdgeOpOverload, + args: tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + + match op: + case exir_ops.edge.aten.add.Tensor: + op, args = self._get_add_replacement(args, meta) + case exir_ops.edge.aten.mul.Tensor: + op, args = self._get_mul_replacement(args, meta) + case exir_ops.edge.aten.minimum.default: + op, args = self._get_minimum_replacement(args, meta) + case exir_ops.edge.aten.maximum.default: + op, args = self._get_maximum_replacement(args, meta) + case exir_ops.edge.aten.permute_copy.default: + op, args = self._get_permute_replacement(args, meta) + case _: + pass + + return super().call_operator(op, args, {}, meta) diff --git a/backends/cortex_m/quantizer/operator_configs.py b/backends/cortex_m/quantizer/operator_configs.py index 6ffc011df27..25d3626a147 100644 --- a/backends/cortex_m/quantizer/operator_configs.py +++ b/backends/cortex_m/quantizer/operator_configs.py @@ -10,6 +10,7 @@ import torch from executorch.backends.cortex_m.quantizer.quantization_configs import ( + INT8_PER_CHANNEL_CONFIG, INT8_PER_TENSOR_CONFIG, ) from torchao.quantization.pt2e.quantizer import OperatorConfig @@ -17,11 +18,27 @@ # ----------------- OPERATOR PATTERN PRESETS ----------------- BINARY_OP_PATTERNS = [ [torch.ops.aten.add.Tensor], + [torch.ops.aten.mul.Tensor], ] LINEAR_OP_PATTERNS = [ [torch.ops.aten.linear.default], [torch.ops.aten.linear.default, torch.ops.aten.relu.default], + [torch.ops.aten.linear.default, torch.ops.aten.relu_.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardtanh.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardtanh_.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid.default], + [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid_.default], +] + +CONV_OP_PATTERNS = [ + [torch.ops.aten.conv2d.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh_.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid.default], + [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid_.default], ] # ----------------- OPERATOR CONFIG PRESETS ----------------- @@ -33,3 +50,8 @@ INT8_PER_TENSOR_CONFIG, LINEAR_OP_PATTERNS, ) + +INT8_CONV_OPERATOR_CONFIG = OperatorConfig( + INT8_PER_CHANNEL_CONFIG, + CONV_OP_PATTERNS, +) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index 7f43a89daad..c6600241b6d 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -5,7 +5,11 @@ import torch -from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e import ( + HistogramObserver, + MinMaxObserver, + PerChannelMinMaxObserver, +) from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationConfig, @@ -21,8 +25,9 @@ INT8_WEIGHT_PER_CHANNEL_QSPEC = QuantizationSpec( dtype=torch.int8, - observer_or_fake_quant_ctr=MinMaxObserver, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, qscheme=torch.per_channel_symmetric, + ch_axis=0, ) INT8_ACTIVATION_PER_TENSOR_QSPEC = QuantizationSpec( @@ -33,8 +38,9 @@ INT8_ACTIVATION_PER_CHANNEL_QSPEC = QuantizationSpec( dtype=torch.int8, - observer_or_fake_quant_ctr=HistogramObserver, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, qscheme=torch.per_channel_affine, + ch_axis=0, ) @@ -61,7 +67,18 @@ def _get_int32_bias_qspec(node): dtype=torch.int32, quant_min=torch.iinfo(torch.int32).min, quant_max=torch.iinfo(torch.int32).max - 1, - qscheme=torch.per_tensor_symmetric, + ) + + +def _get_int32_per_channel_bias_qspec(node): + return DerivedQuantizationSpec( + derived_from=[(node.args[0], node), (node.args[1], node)], # type: ignore[list-item] + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + qscheme=torch.per_channel_symmetric, + ch_axis=0, ) @@ -75,8 +92,8 @@ def _get_int32_bias_qspec(node): INT8_PER_CHANNEL_CONFIG = QuantizationConfig( - INT8_ACTIVATION_PER_CHANNEL_QSPEC, - INT8_ACTIVATION_PER_CHANNEL_QSPEC, + INT8_ACTIVATION_PER_TENSOR_QSPEC, + INT8_ACTIVATION_PER_TENSOR_QSPEC, INT8_WEIGHT_PER_CHANNEL_QSPEC, - _get_int32_bias_qspec, + _get_int32_per_channel_bias_qspec, ) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index d75fa45ed1e..8bfc32049ed 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -7,21 +7,26 @@ from typing import Callable, List, Optional import torch - from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor - from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager from executorch.backends.cortex_m.quantizer.operator_configs import ( + BINARY_OP_PATTERNS, + CONV_OP_PATTERNS, INT8_BINARY_OPS_OPERATOR_CONFIG, + INT8_CONV_OPERATOR_CONFIG, INT8_LINEAR_OPERATOR_CONFIG, ) +from executorch.backends.cortex_m.quantizer.quantization_configs import ( + INT8_PER_TENSOR_CONFIG, +) from torch._ops import OpOverload from torch.fx import GraphModule, Node from torchao.quantization.pt2e.quantizer import ( ComposableQuantizer, QuantizationAnnotation, Quantizer, + SharedQuantizationSpec, ) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY @@ -34,7 +39,7 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool: """ if node is None: return False - if node.target not in [torch.ops.aten.add.Tensor]: + if [node.target] not in BINARY_OP_PATTERNS: return False if len(node.all_input_nodes) == 2: @@ -44,12 +49,33 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool: return False + def nchw_filter(self, node: Optional[Node]) -> bool: + """ + Filter function to exclude nodes that use NCHW memory format. + """ + if node is None: + return False + if [node.target] not in CONV_OP_PATTERNS: + return False + + tensor = get_first_fake_tensor(node) + if tensor is None: + return False + + return not tensor.is_contiguous(memory_format=torch.channels_last) + def __init__(self) -> None: - quantizers: List[OperatorConfigQuantizer] = [ + quantizers: List[Quantizer] = [ OperatorConfigQuantizer( INT8_BINARY_OPS_OPERATOR_CONFIG, filter_fn=self.broadcasting_filter ), OperatorConfigQuantizer(INT8_LINEAR_OPERATOR_CONFIG), + OperatorConfigQuantizer( + INT8_CONV_OPERATOR_CONFIG, filter_fn=self.nchw_filter + ), + InputQuantizer(INT8_PER_TENSOR_CONFIG), + OutputQuantizer(INT8_PER_TENSOR_CONFIG), + SharedQspecQuantizer(), ] super().__init__(quantizers) @@ -101,7 +127,6 @@ def check_pattern( Returns the matched nodes if the given node matches the given pattern, otherwise None. """ match: List[Node] = [] - node = list(node.users)[0] if node and len(node.users) > 0 else None for pattern_target in pattern: if self.check_node(node, pattern_target): @@ -183,7 +208,7 @@ def annotate_match( config.input_activation if config else None ) - if all(node not in match for node in node.users): + if all(node not in match for node in node.users) and output_qspec is None: output_qspec = config.output_activation if config else None node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( @@ -197,3 +222,171 @@ def annotate(self, model: GraphModule) -> None: def validate(self, model: GraphModule) -> bool: return True + + +class InputQuantizer(Quantizer): + """ + Quantizes only the input activations of the graph. + """ + + def __init__( + self, + quantization_config: QuantizationConfig, + filter_fn: Callable[[Node], bool] = lambda node: False, + ) -> None: + self.quantization_config = quantization_config + self.filter_fn = filter_fn + + def annotate(self, model: GraphModule) -> None: + for node in model.graph.nodes: + is_placeholder = node.op == "placeholder" + is_filtered = self.filter_fn(node) + if is_placeholder and not is_filtered: + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + {}, self.quantization_config.output_activation + ) + + def validate(self, model: GraphModule) -> bool: + return True + + +class OutputQuantizer(Quantizer): + """ + Quantizes only the output activations of the graph. + """ + + def __init__( + self, + quantization_config: QuantizationConfig, + filter_fn: Callable[[Node], bool] = lambda node: False, + ) -> None: + self.quantization_config = quantization_config + self.filter_fn = filter_fn + + def annotate(self, model: GraphModule) -> None: + output_node = model.graph.output_node() + input_qspec_map = { + n: self.quantization_config.input_activation + for n in output_node.all_input_nodes + if not self.filter_fn(n) + } + output_qspec = self.quantization_config.output_activation + output_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map, output_qspec + ) + + def validate(self, model: GraphModule) -> bool: + return True + + +class SharedQspecQuantizer(Quantizer): + """ + Special quantizer for assuring that given ops share the same quantization parameters on all input and outputs, + i.e. ops which does not change the scale such as clone, min/max, transposes and so on. + + Args: + targets (Optional[List[OpOverload]]): List of operator overloads to apply shared quantization spec to. + If None, a default list of supported ops is used. + """ + + SHARED_QSPEC_OPS_DEFAULT: List[OpOverload] = [ + # Clone + torch.ops.aten.clone.default, + torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten.detach_.default, + # Min/Max/Mean + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + # Data shuffling + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_copy.int, + torch.ops.aten.t_copy.default, + torch.ops.aten.t.default, + # Change shape + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze_copy.default, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.reshape.default, + torch.ops.aten.view.default, + torch.ops.aten.view_as.default, + torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten.unflatten.int, + torch.ops.aten.flatten.using_ints, + ] + + def __init__(self, targets: Optional[List[OpOverload]] = None) -> None: + super().__init__() + if targets is None: + self.targets = self.SHARED_QSPEC_OPS_DEFAULT + else: + self.targets = targets + + def _is_annotated(self, node: Node) -> bool: + return Q_ANNOTATION_KEY in node.meta + + def _annotate_shared_cluster(self, root_node: Node) -> None: + """ + Finds a cluster of unannotated nodes starting in root_node and annotates them with a common + SharedQuantizationSpec. + """ + + shared_nodes = set() + leaf_nodes = set() + bfs_queue = [root_node] + + while bfs_queue: + node = bfs_queue.pop(0) + + if self._is_annotated(node): + leaf_nodes.add(node) + continue + if node.op == "get_attr": + continue + + if node.target not in self.targets: + raise NotImplementedError( + ( + f"{SharedQspecQuantizer.__name__} found unannoted node '{node.name}' in neighbour_nodes " + "which is not in the supported target list. This might be the case either because:\n" + "1) The op should have shared qspec but is not in the target list. " + "In this case, try modifying the list using the targets field in the initializer.\n" + "2) The op should not be quantized, which is not currently supported by the SharedQspecQuantizer." + ) + ) + + shared_nodes.add(node) + neighbour_nodes = list(node.all_input_nodes) + list(node.users) + for n in neighbour_nodes: + if n not in shared_nodes: + bfs_queue.append(n) + + # The selection of root node for the shared_qspec is important for + # torchao.quantization.pt2e.prepare._create_obs_or_fq_from_qspec: + # 1. For regular QuantizationSpecs, it creates a new observer + # 2. For SharedQuantizationSpecs, it returns the observer created for it's root node + # 3. It handles nodes in the order they appear in graph.nodes + # This means that the root node of the shared group needs to be the first annotated node that appears in graph.nodes. + shared_root_node = next(n for n in root_node.graph.nodes if n in leaf_nodes) + shared_qspec = SharedQuantizationSpec(shared_root_node) + + for node in shared_nodes: + input_qspec_map = {n: shared_qspec for n in node.all_input_nodes} + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map, shared_qspec + ) + + def annotate(self, model: GraphModule) -> None: + for node in model.graph.nodes: + if node.target in self.targets and not self._is_annotated(node): + self._annotate_shared_cluster(node) + + def validate(self, model: GraphModule) -> bool: + return True diff --git a/backends/cortex_m/test/TARGETS b/backends/cortex_m/test/TARGETS index b7a04f3efab..292a087a88a 100644 --- a/backends/cortex_m/test/TARGETS +++ b/backends/cortex_m/test/TARGETS @@ -8,13 +8,11 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") load("targets.bzl", "define_common_targets") oncall("executorch") - python_unittest( name="test_replace_quant_nodes", srcs=[ "test_helpers_passes_utils.py", "test_replace_quant_nodes.py", - "test_quantize_op_fusion_pass.py", ], deps=[ "//pytorch/ao:torchao", # @manual diff --git a/backends/cortex_m/test/misc/test_quantization.py b/backends/cortex_m/test/misc/test_quantization.py new file mode 100644 index 00000000000..d4f84e4f075 --- /dev/null +++ b/backends/cortex_m/test/misc/test_quantization.py @@ -0,0 +1,359 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +class SharedQspecMulipleClusters(torch.nn.Module): + """Three linear shared qspec clusters.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 8, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 8, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 4, + } + + def forward(self, x): + x1 = torch.clone(x) + x2 = x1 + x1 + x3 = torch.clone(x2) + x3 = torch.clone(x3) + x3 = torch.clone(x3) + x4 = x3 + x3 + x5 = torch.transpose(x4, 2, 1) + return x5 + + +class SharedQspecInputForkNonShared(torch.nn.Module): + """Shared qspec cluster with an input fork with both inputs as non-shared qspecs.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + } + + def forward(self, x, y): + z = torch.maximum(x, y) + return torch.flatten(z) + + +class SharedQspecInputForkShared(torch.nn.Module): + """Shared qspec cluster with an input fork with both inputs as shared qspecs.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 5, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x, y): + x = torch.clone(x) + y = torch.permute(y, (0, 1, 3, 2)) + z = torch.minimum(x, y) + return z + + +class SharedQspecInputForkXShared(torch.nn.Module): + """Shared qspec cluster with an input fork with left input as shared qspec.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x, y): + x = torch.t_copy(x) + z = torch.maximum(x, y) + return z + + +class SharedQspecInputForkYShared(torch.nn.Module): + """Shared qspec cluster with an input fork with right input as shared qspec.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 5, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x, y): + y = torch.clone(y) + z = torch.minimum(x, y) + return torch.squeeze(z) + + +class SharedQspecInputForkXConstant(torch.nn.Module): + """Shared qspec cluster with an input fork with left input as global constant.""" + + ops_before_transforms = {} + ops_after_transforms = {} + constant = torch.tensor(5.0) + + def forward(self, x): + return torch.minimum(self.constant, x) + + +class SharedQspecInputForkYConstant(torch.nn.Module): + """Shared qspec cluster with an input fork with left input as local constant.""" + + ops_before_transforms = {} + ops_after_transforms = {} + + def forward(self, x): + return torch.maximum(x, torch.tensor(5.0)) + + +class SharedQspecOutputForkNonShared(torch.nn.Module): + """Shared qspec cluster with an output fork with both outputs as non-shared qspecs.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + } + + def forward(self, x): + x = torch.unsqueeze(x, 0) + y = x + x + return x, y + + +class SharedQspecOutputForkShared(torch.nn.Module): + """Shared qspec cluster with an output fork with both outputs as shared qspecs.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 6, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x): + x = torch.unsqueeze(x, 0) + y = torch.clone(x) + z = torch.permute_copy(x, (0, 2, 1, 3)) + return y, z, x + + +class SharedQspecManyForks(torch.nn.Module): + """Shared qspec cluster with a number of forks to testmore complex structures.""" + + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_maximum_default": 2, + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 9, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x): + x1 = torch.clone(x) + x2 = torch.maximum(x, x1) + x3 = torch.maximum(x, torch.t(x2)) + x4 = torch.minimum(x2, x3) + + return x4 + + +class SharedQspecSurroundedQuantizedOp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + } + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1, + } + + def forward(self, x): + x1 = torch.clone(x) + x2 = torch.add(x1, x1) + x3 = torch.maximum(x1, x2) + return x3 + + +class SharedQspecSurroundedQuantizedOpConstant(torch.nn.Module): + ops_before_transforms = {} + ops_after_transforms = {} + + def forward(self, x): + x1 = torch.clone(x) + x2 = torch.add(x1, torch.ones(2, 2)) + x3 = torch.maximum(x1, x2) + return x3 + + +class SharedQspecSub(torch.nn.Module): + ops_before_transforms = {} + ops_after_transforms = {} + + def forward(self, x, y): + return torch.clone(x - y) + + +test_cases = { + "multiple_clusters": McuTestCase( + SharedQspecMulipleClusters(), + (ramp_tensor(-2, 2, (2, 3, 4)),), + ), + "input_fork_non_shared": McuTestCase( + SharedQspecInputForkNonShared(), + (ramp_tensor(-2, 2, (2, 3, 4)), ramp_tensor(-1, 3, (2, 3, 4))), + ), + "input_fork_shared": McuTestCase( + SharedQspecInputForkShared(), + (ramp_tensor(-2, 2, (2, 3, 4, 5)), ramp_tensor(-1, 3, (2, 3, 5, 4))), + ), + "input_fork_x_shared": McuTestCase( + SharedQspecInputForkXShared(), + (ramp_tensor(-2, 2, (3, 4)), ramp_tensor(-1, 3, (4, 3))), + ), + "input_fork_y_shared": McuTestCase( + SharedQspecInputForkYShared(), + (ramp_tensor(-2, 2, (2, 3, 4)), ramp_tensor(-1, 3, (2, 3, 4))), + ), + "input_fork_x_constant": McuTestCase( + SharedQspecInputForkXConstant(), + (ramp_tensor(-2, 2, (2, 3, 4)),), + ), + "input_fork_y_constant": McuTestCase( + SharedQspecInputForkYConstant(), + (ramp_tensor(-2, 2, (2, 3, 4)),), + ), + "surrounded_quantized_op": McuTestCase( + SharedQspecSurroundedQuantizedOp(), + (ramp_tensor(-128, 2, (2, 3, 4)),), + ), + "surrounded_quantized_op_constant": McuTestCase( + SharedQspecSurroundedQuantizedOpConstant(), + (ramp_tensor(-2, 2, (2, 2)),), + ), + "output_fork_non_shared": McuTestCase( + SharedQspecOutputForkNonShared(), + (ramp_tensor(-2, 2, (2, 3, 4)),), + ), + "output_fork_shared": McuTestCase( + SharedQspecOutputForkShared(), + (ramp_tensor(-2, 2, (2, 3, 4)),), + ), + "many_forks": McuTestCase( + SharedQspecManyForks(), + (ramp_tensor(-20, 2, (4, 4)),), + ), + "non-quantized_op": McuTestCase( + SharedQspecSub(), + (ramp_tensor(0, 10, (5, 5)), ramp_tensor(0, 1, (5, 5))), + ), +} + +xfails = { + "surrounded_quantized_op_constant": "Numerical error since the add is forced to have non-correct qparams.", + "non-quantized_op": "Non-quantized ops are not currently supported in SharedQspecQuantizer.", +} + + +@parametrize("test_case", test_cases, xfails=xfails) +def test_shared_qspec_quantizer(test_case): + """ + Test that ops which does not change dynamic range are able to use int8 portable kernels. + """ + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + # Check that all nodes in the graph are in int8 + artifact = tester.get_artifact() + for node in artifact.exported_program().module().graph.nodes: + if node.op != "call_function": + continue + if node.target == exir_ops.edge.cortex_m.dequantize_per_tensor.default: + continue + + assert get_first_fake_tensor(node).dtype == torch.int8, f"{node.name}" diff --git a/backends/cortex_m/test/ops/test_activation.py b/backends/cortex_m/test/ops/test_activation.py new file mode 100644 index 00000000000..bc20d364674 --- /dev/null +++ b/backends/cortex_m/test/ops/test_activation.py @@ -0,0 +1,409 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMLinearReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=4, out_features=3): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + +class CortexMLinearHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-0.25, max_val=0.75): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMLinearReLU6(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=8, out_features=8): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(self.linear(x)) + + +class CortexMLinearReLUInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=8, out_features=8): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.linear(x)) + + +class CortexMLinearHardtanhInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-1.0, max_val=1.0): + super().__init__() + self.linear = torch.nn.Linear(8, 8, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=True) + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMLinearHardsigmoid(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, in_features=6, out_features=6): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=False) + self.act = torch.nn.Hardsigmoid() + + def forward(self, x): + return self.act(self.linear(x)) + + +class CortexMConv2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + +class CortexMConv2DReLU6(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 3, stride=2, padding=1, bias=False) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(self.conv(x)) + + +class CortexMConv2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-2.0, max_val=2.0): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=True) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DReLUInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.conv(x)) + + +class CortexMConv2DHardtanhInplace(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, min_val=-0.5, max_val=0.5): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=False) + self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=True) + torch.nn.init.ones_(self.conv.weight) + + def forward(self, x): + return self.act(self.conv(x)) + + +class CortexMConv2DHardsigmoid(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardsigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1, bias=False) + self.act = torch.nn.Hardsigmoid(inplace=True) + self.conv.weight.data.fill_(1) + + def forward(self, x): + return self.act(self.conv(x)) + + +test_cases = { + # Linear + activation tests with various data ranges + "linear_relu_small_range": McuTestCase( + model=CortexMLinearReLU(), + example_inputs=(ramp_tensor(-10, 10, (1, 4)),), + ), + "linear_relu_large_range": McuTestCase( + model=CortexMLinearReLU(in_features=16, out_features=16), + example_inputs=(ramp_tensor(-100, 100, (2, 16)),), + ), + "linear_relu_negative": McuTestCase( + model=CortexMLinearReLU(in_features=8, out_features=8), + example_inputs=(ramp_tensor(-50, 0, (1, 8)),), + ), + "linear_relu6": McuTestCase( + model=CortexMLinearReLU6(), + example_inputs=(ramp_tensor(-2, 10, (1, 8)),), + ), + "linear_relu_inplace": McuTestCase( + model=CortexMLinearReLUInplace(), + example_inputs=(ramp_tensor(-5, 5, (2, 8)),), + ), + "linear_hardtanh_symmetric": McuTestCase( + model=CortexMLinearHardtanh(min_val=-0.5, max_val=0.5), + example_inputs=(ramp_tensor(-1, 1, (2, 1, 4)),), + ), + "linear_hardtanh_asymmetric": McuTestCase( + model=CortexMLinearHardtanh(min_val=-1.5, max_val=0.25), + example_inputs=(ramp_tensor(-2, 1, (1, 4)),), + ), + "linear_hardtanh_large_range": McuTestCase( + model=CortexMLinearHardtanh(min_val=-10.0, max_val=10.0), + example_inputs=(ramp_tensor(-20, 20, (2, 4)),), + ), + "linear_hardtanh_inplace": McuTestCase( + model=CortexMLinearHardtanhInplace(min_val=-0.75, max_val=0.75), + example_inputs=(ramp_tensor(-2, 2, (1, 8)),), + ), + # Convolution + activation tests with various configurations + "conv2d_relu_small_kernel": McuTestCase( + model=CortexMConv2DReLU(), + example_inputs=( + ramp_tensor(-5, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu_large_range": McuTestCase( + model=CortexMConv2DReLU(), + example_inputs=( + ramp_tensor(-50, 50, (2, 4, 16, 16)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu6_stride": McuTestCase( + model=CortexMConv2DReLU6(), + example_inputs=( + ramp_tensor(-10, 20, (1, 3, 12, 12)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu_inplace": McuTestCase( + model=CortexMConv2DReLUInplace(), + example_inputs=( + ramp_tensor(-3, 3, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_narrow": McuTestCase( + model=CortexMConv2DHardtanh(min_val=-0.5, max_val=0.5), + example_inputs=( + ramp_tensor(-2, 2, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_wide": McuTestCase( + model=CortexMConv2DHardtanh(min_val=-5.0, max_val=5.0), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_hardtanh_inplace": McuTestCase( + model=CortexMConv2DHardtanhInplace(min_val=-10.0, max_val=10.0), + example_inputs=( + ramp_tensor(-15, 15, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "linear_hardsigmoid": McuTestCase( + model=CortexMLinearHardsigmoid(in_features=6, out_features=4), + example_inputs=(ramp_tensor(-8, 8, (2, 6)),), + ), + "conv2d_hardsigmoid_inplace": McuTestCase( + model=CortexMConv2DHardsigmoid(), + example_inputs=( + ramp_tensor(-4, 4, (1, 1, 6, 6)).to(memory_format=torch.channels_last), + ), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_activation(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +@parametrize("test_case", test_cases) +def test_implementation_activation(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py index 4389b463076..8c355fd2e39 100644 --- a/backends/cortex_m/test/ops/test_add.py +++ b/backends/cortex_m/test/ops/test_add.py @@ -59,17 +59,6 @@ class CortexMTensorAdd(Model): } -class CortexMTensorAddBroadcast(Model): - # TODO: Quantize and accelerate broadcasted adds - ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, - } - - ops_after_transforms = { - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, - } - - class CortexMAlphaAdd(ModelAlpha): ops_before_transforms = { "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, @@ -126,15 +115,15 @@ class CortexMAlphaAdd(ModelAlpha): (torch.rand(2, 2) * 10, torch.rand(2, 2)), ), "broadcast_1": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), (torch.ones(1), torch.ones(2, 2, 2, 2)), ), "broadcast_2": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), (torch.ones((2, 1, 1, 1)), torch.ones(1)), ), "broadcast_3": McuTestCase( - CortexMTensorAddBroadcast(), + CortexMTensorAdd(), ( ramp_tensor(-2, 2, (2, 1, 2, 1)), ramp_tensor(-5, 5, (1, 2, 1, 2)), @@ -150,7 +139,7 @@ class CortexMAlphaAdd(ModelAlpha): } -dialect_xfails = { +xfails = { "self_scalar": ( "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", AttributeError, @@ -163,10 +152,13 @@ class CortexMAlphaAdd(ModelAlpha): "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", AssertionError, ), + "broadcast_1": "Broadcasting not yet supported in Cortex-M backend", + "broadcast_2": "Broadcasting not yet supported in Cortex-M backend", + "broadcast_3": "Broadcasting not yet supported in Cortex-M backend", } -@parametrize("test_case", test_cases, xfails=dialect_xfails) +@parametrize("test_case", test_cases, xfails=xfails) def test_dialect_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( @@ -174,23 +166,7 @@ def test_dialect_add(test_case): ) -implementation_xfails = { - "self_scalar": ( - "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", - AttributeError, - ), - "scalar_scalar": ( - "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", - AttributeError, - ), - "alpha": ( - "Expecting kwargs for aten op IR to be empty - alpha arg not supported.", - AssertionError, - ), -} - - -@parametrize("test_case", test_cases, xfails=implementation_xfails) +@parametrize("test_case", test_cases, xfails=xfails) def test_implementation_add(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py new file mode 100644 index 00000000000..8a67d1b7de1 --- /dev/null +++ b/backends/cortex_m/test/ops/test_conv.py @@ -0,0 +1,212 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMConv1D(torch.nn.Module): + ops_before_transforms = {} + ops_after_transforms = {} + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv1d(*args, **kwargs, bias=False) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2D(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=False) + self.conv.weight.data.fill_(1.0) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2DBias(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=True) + + def forward(self, x): + + return self.conv(x) + + +class CortexMConv3D(torch.nn.Module): + ops_before_transforms = {} + + ops_after_transforms = {} + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv3d(*args, **kwargs, bias=False) + self.conv.weight.data.fill_(2.0) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2Dx3(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 3, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False) + self.conv2 = torch.nn.Conv2d(8, 16, 3, padding=1, bias=False) + self.conv3 = torch.nn.Conv2d(16, 8, 3, padding=1, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + + +# in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode +test_cases = { + "conv2d": McuTestCase( + model=CortexMConv2D(2, 4, 3), + example_inputs=( + ramp_tensor(1, 5, (1, 2, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_stride": McuTestCase( + model=CortexMConv2D(3, 4, (1, 2), stride=2), + example_inputs=( + ramp_tensor(-100, 10, (3, 3, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_padding": McuTestCase( + model=CortexMConv2D(3, 2, 3, padding=(4, 1)), + example_inputs=( + ramp_tensor(0, 1, (2, 3, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_dilation": McuTestCase( + model=CortexMConv2D(1, 4, 3, dilation=(2, 2)), + example_inputs=( + ramp_tensor(0, 10, (3, 1, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_groups": McuTestCase( + model=CortexMConv2D(4, 4, 1, groups=2), + example_inputs=( + ramp_tensor(0, 10, (1, 4, 1, 1)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_bias_ch_out_1": McuTestCase( + model=CortexMConv2DBias(5, 1, 1), + example_inputs=( + ramp_tensor(0, 10, (2, 5, 3, 3)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_bias_ch_out_4": McuTestCase( + model=CortexMConv2DBias(5, 4, (1, 2)), + example_inputs=( + ramp_tensor(-3, 3, (2, 5, 10, 10)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_nchw": McuTestCase( + model=CortexMConv2D(5, 5, 1), + example_inputs=(ramp_tensor(0, 10, (1, 5, 8, 8)),), + ), + "conv1d": McuTestCase( + model=CortexMConv1D(1, 1, 1), + example_inputs=(ramp_tensor(0, 10, (1, 3, 2)),), + ), + "conv3d": McuTestCase( + model=CortexMConv3D(1, 1, 1), + example_inputs=( + ramp_tensor(-1000, 1000, (2, 1, 3, 3, 3)).to( + memory_format=torch.channels_last_3d + ), + ), + ), + "conv2d_x3": McuTestCase( + model=CortexMConv2Dx3(), + example_inputs=( + ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last), + ), + ), +} + + +xfails_dialect = { + "conv2d_dilation": "NotImplementedError: 'slow_conv_dilated<>' not implemented for 'Int'", + "conv1d": "Currently not supported.", + "conv2d_nchw": "Currently not supported.", + "conv3d": "Currently not supported.", +} + + +@parametrize("test_case", test_cases, xfails=xfails_dialect) +def test_dialect_conv2d(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +xfails_implementation = { + "conv1d": "Currently not supported.", + "conv2d_nchw": "Currently not supported.", + "conv3d": "Currently not supported.", +} + + +@parametrize("test_case", test_cases, xfails=xfails_implementation) +def test_implementation_conv2d(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/test_linear.py b/backends/cortex_m/test/ops/test_linear.py index 4ab5ca99f15..e81daa7e83e 100644 --- a/backends/cortex_m/test/ops/test_linear.py +++ b/backends/cortex_m/test/ops/test_linear.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. -import pytest import torch +from executorch.backends.arm.test.common import parametrize from executorch.backends.cortex_m.test.tester import ( CortexMTester, McuTestCase, @@ -13,12 +13,9 @@ ) -class CortexMMm(torch.nn.Module): - def forward(self, x, y): - return torch.mm(x, y) - +class CortexMLinear(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_mm_default": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, } @@ -29,32 +26,45 @@ def forward(self, x, y): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + return self.linear(x) -class CortexMBmm(torch.nn.Module): - def forward(self, x, y): - return torch.bmm(x, y) +class CortexMLinearX3(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_bmm_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_aten_linear_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, } ops_after_transforms = { - "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 3, "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } + def __init__(self, *args, **kwargs): + super().__init__() + self.linear = torch.nn.Linear(*args, bias=False) + self.linear.weight.data.fill_(1.0) + + def forward(self, x): + x = self.linear(x) + x = self.linear(x) + x = self.linear(x) + return x -class CortexMAddmm(torch.nn.Module): - def forward(self, x, y, z, alpha=None, beta=None): - return torch.addmm(beta, x, alpha, y, z) +class CortexMLinearBias(torch.nn.Module): ops_before_transforms = { - "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 1, "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, } ops_after_transforms = { @@ -63,90 +73,23 @@ def forward(self, x, y, z, alpha=None, beta=None): "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, } - -class CortexMAt(CortexMMm): - def forward(self, x, y): - return x @ y - - -class CortexMMatmul(CortexMMm): - def forward(self, x, y): - return torch.matmul(x, y) - - -class CortexMLinear(CortexMMatmul): - def __init__(self, *args, **kwargs): - super().__init__() - self.linear = torch.nn.Linear(*args, bias=False) - - def forward(self, x): - return self.linear(x) - - -class CortexMLinearBias(CortexMAddmm): def __init__(self, *args, **kwargs): super().__init__() self.linear = torch.nn.Linear(*args, bias=True) self.relu = torch.nn.ReLU() def forward(self, x): - return self.relu(self.linear(x)) + return self.linear(x) test_cases = { - "mm": McuTestCase( - model=CortexMMm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "bmm": McuTestCase( - model=CortexMBmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16, 16)), - ramp_tensor(0, 10, (1, 16, 16)), - ), - ), - "addmm": McuTestCase( - model=CortexMAddmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ramp_tensor(0, 10, (16, 16)), - 2, - 4, - ), - ), - "addmm_scalars": McuTestCase( - model=CortexMAddmm(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "@-operator": McuTestCase( - model=CortexMAt(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), - "matmul": McuTestCase( - model=CortexMMatmul(), - example_inputs=( - ramp_tensor(0, 10, (1, 16)), - ramp_tensor(0, 10, (16, 16)), - ), - ), "linear_rank1": McuTestCase( - model=CortexMLinear(2, 3), - example_inputs=(ramp_tensor(-1, 1, (2,)),), + model=CortexMLinear(1, 2), + example_inputs=(torch.Tensor([1]),), ), "linear_rank2_pos": McuTestCase( - model=CortexMLinear(8, 3), - example_inputs=(ramp_tensor(0, 10, (2, 8)),), + model=CortexMLinear(1, 2), + example_inputs=(ramp_tensor(-1, 1, (1, 1)),), ), "linear_rank3_neg": McuTestCase( model=CortexMLinear(5, 3), @@ -164,22 +107,24 @@ def forward(self, x): model=CortexMLinearBias(61, 37), example_inputs=(ramp_tensor(0, 10, (8, 61)),), ), + "linear_x3": McuTestCase( + model=CortexMLinearX3(4, 4), + example_inputs=(ramp_tensor(0, 10, (2, 4)),), + ), } -@pytest.mark.skip( - reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." -) +@parametrize("test_case", test_cases) def test_dialect_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( - test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, ) -@pytest.mark.skip( - reason="Skipping until the quantized_linear_fusion_pass is updated to work with non decomposed linear ops." -) +@parametrize("test_case", test_cases) def test_implementation_linear(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation() + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/test_maximum.py b/backends/cortex_m/test/ops/test_maximum.py new file mode 100644 index 00000000000..58d477a9516 --- /dev/null +++ b/backends/cortex_m/test/ops/test_maximum.py @@ -0,0 +1,83 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMTensorMaximum(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_maximum_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_maximum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x, y): + return torch.maximum(x, y) + + +test_cases = { + "tensor_small": McuTestCase( + CortexMTensorMaximum(), + ( + torch.tensor([[1.0, -2.0], [3.5, -4.5]]), + torch.tensor([[0.5, -1.0], [4.0, -3.5]]), + ), + ), + "tensor_rand": McuTestCase( + CortexMTensorMaximum(), + ( + torch.rand(2, 2, 2) * 4 - 2, + torch.rand(2, 2, 2) * 4 - 2, + ), + ), + "broadcast": McuTestCase( + CortexMTensorMaximum(), + ( + ramp_tensor(-2, 2, (2, 1, 2)), + ramp_tensor(-3, 3, (1, 2, 1)), + ), + ), + "broadcast_rank4": McuTestCase( + CortexMTensorMaximum(), + ( + ramp_tensor(-4, 4, (1, 2, 3, 1)), + ramp_tensor(-6, 6, (4, 1, 1, 3)), + ), + ), + "broadcast_scalar": McuTestCase( + CortexMTensorMaximum(), + ( + torch.tensor(1.0), + ramp_tensor(-6, 6, (4, 1, 1, 3)), + ), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_maximum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +@parametrize("test_case", test_cases) +def test_implementation_maximum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_minimum.py b/backends/cortex_m/test/ops/test_minimum.py new file mode 100644 index 00000000000..633ccdbf483 --- /dev/null +++ b/backends/cortex_m/test/ops/test_minimum.py @@ -0,0 +1,104 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMSelfMinimum(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.minimum(x, x) + + +class CortexMTensorMinimum(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_minimum_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x, y): + return torch.minimum(x, y) + + +test_cases = { + "self_rank_1": McuTestCase( + CortexMSelfMinimum(), + (ramp_tensor(-5, 5, (10,)),), + ), + "self_rank_3": McuTestCase( + CortexMSelfMinimum(), + (ramp_tensor(-10, 10, (2, 3, 4)),), + ), + "tensor_small": McuTestCase( + CortexMTensorMinimum(), + ( + torch.tensor([[1.0, -2.0], [3.5, -4.5]]), + torch.tensor([[0.5, -3.0], [3.0, -4.0]]), + ), + ), + "tensor_rand": McuTestCase( + CortexMTensorMinimum(), + ( + torch.rand(2, 2, 2) * 4 - 2, + torch.rand(2, 2, 2) * 4 - 2, + ), + ), + "broadcast": McuTestCase( + CortexMTensorMinimum(), + ( + ramp_tensor(-2, 2, (2, 1, 2)), + ramp_tensor(-3, 3, (1, 2, 1)), + ), + ), + "broadcast_rank4": McuTestCase( + CortexMTensorMinimum(), + ( + ramp_tensor(-4, 4, (1, 2, 3, 1)), + ramp_tensor(-6, 6, (4, 1, 1, 3)), + ), + ), +} + + +xfails = {} + + +@parametrize("test_case", test_cases, xfails=xfails) +def test_dialect_minimum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +@parametrize("test_case", test_cases, xfails=xfails) +def test_implementation_minimum(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/ops/test_mul.py b/backends/cortex_m/test/ops/test_mul.py index a2f13760bf0..35c958ce8d4 100644 --- a/backends/cortex_m/test/ops/test_mul.py +++ b/backends/cortex_m/test/ops/test_mul.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. -import pytest import torch from executorch.backends.arm.test.common import parametrize from executorch.backends.cortex_m.test.tester import ( @@ -91,11 +90,11 @@ class CortexMTensorMul(Model): ), "tensor_scalar": McuTestCase( CortexMScalarMul(), - (torch.ones(2, 2), 1.0), + (torch.ones(1), 1.0), ), "scalar_tensor": McuTestCase( CortexMScalarMul(), - (1000.0, torch.ones(2, 2)), + (1000.0, torch.ones(1)), ), "broadcast_1": McuTestCase( CortexMTensorMul(), @@ -115,17 +114,32 @@ class CortexMTensorMul(Model): } -@pytest.mark.skip(reason="Not implemented yet") -@parametrize("test_case", test_cases) +xfail_cases = { + "self_scalar": ( + "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", + AttributeError, + ), + "scalar_scalar": ( + "'float' object has not attribute 'fake_mode' - scalar only ops not supported.", + AttributeError, + ), + "broadcast_1": "Broadcasting not yet supported in Cortex-M backend", + "broadcast_2": "Broadcasting not yet supported in Cortex-M backend", + "broadcast_3": "Broadcasting not yet supported in Cortex-M backend", +} + + +@parametrize("test_case", test_cases, xfails=xfail_cases) def test_dialect_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) tester.test_dialect( - test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, ) -@pytest.mark.skip(reason="Not implemented yet") -@parametrize("test_case", test_cases) +@parametrize("test_case", test_cases, xfails=xfail_cases) def test_implementation_mul(test_case): tester = CortexMTester(test_case.model, test_case.example_inputs) - tester.test_implementation() + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/ops/test_transpose.py b/backends/cortex_m/test/ops/test_transpose.py new file mode 100644 index 00000000000..de16c2f81ad --- /dev/null +++ b/backends/cortex_m/test/ops/test_transpose.py @@ -0,0 +1,102 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + +OPS_BEFORE_PASSES = { + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, +} + +OPS_AFTER_PASSES = { + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, +} + + +class CortexMPermute(torch.nn.Module): + ops_before_transforms = OPS_BEFORE_PASSES + ops_after_transforms = OPS_AFTER_PASSES + + def __init__(self, perms): + super().__init__() + self.perms = perms + + def forward(self, x): + return x.permute(self.perms) + + +class CortexMTranspose(torch.nn.Module): + ops_before_transforms = OPS_BEFORE_PASSES + ops_after_transforms = OPS_AFTER_PASSES + + def __init__(self, dim0, dim1): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + return x.transpose(self.dim0, self.dim1) + + +class CortexMT(torch.nn.Module): + ops_before_transforms = OPS_BEFORE_PASSES + ops_after_transforms = OPS_AFTER_PASSES + + def forward(self, x): + return x.t() + + +test_cases = { + "permute_nhwc_to_nchw": McuTestCase( + CortexMPermute((0, 3, 1, 2)), + (ramp_tensor(-0.5, 0.5, (2, 3, 4, 2)),), + ), + "permute_nchw_to_nhwc_neg_index": McuTestCase( + CortexMPermute((0, -2, -1, -3)), + (ramp_tensor(10, 100, (2, 3, 4, 2)),), + ), + "permute_rank_1": McuTestCase( + CortexMPermute((0,)), + (ramp_tensor(10, 100, (3)),), + ), + "transpose_1_2": McuTestCase( + CortexMTranspose(1, 2), + (ramp_tensor(-1.0, 1.0, (1, 3, 4)),), + ), + "transpose_0_1": McuTestCase( + CortexMTranspose(0, 1), + (ramp_tensor(-2.0, 2.0, (2, 3, 4, 3)),), + ), + "t_operator": McuTestCase( + CortexMT(), + (ramp_tensor(-0.5, 0.5, (4, 2)),), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_transpose(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +@parametrize("test_case", test_cases) +def test_implementation_transpose(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1) diff --git a/backends/cortex_m/test/test_quantize_op_fusion_pass.py b/backends/cortex_m/test/test_quantize_op_fusion_pass.py deleted file mode 100644 index 95845597947..00000000000 --- a/backends/cortex_m/test/test_quantize_op_fusion_pass.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import executorch -import executorch.backends.cortex_m.ops.operators # noqa - -import torch - -from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( - QuantizedOpFusionPass, -) -from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ( - ReplaceQuantNodesPass, -) -from executorch.backends.cortex_m.test.test_helpers_passes_utils import ( - AddQuantizer, - check_count, - get_node_args, -) -from executorch.exir.dialects._ops import ops as exir_ops -from torch.export import export -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - - -class TestQuantizedOpFusionPass(unittest.TestCase): - """ - Test suite for the QuantizedOpFusionPass which fuses dequantize->add->quantize patterns - into a single quantized_add operation with AoT-computed parameters. - """ - - def setUp(self): - """Set up common test fixtures""" - self.example_inputs = (torch.randn(4, 8), torch.randn(4, 8)) - - def _prepare_quantized_model(self, model_class): - """Helper to prepare a quantized model for testing""" - model = model_class() - - # Export and quantize - exported_model = export(model.eval(), self.example_inputs, strict=True).module() - prepared_model = prepare_pt2e(exported_model, AddQuantizer()) - quantized_model = convert_pt2e(prepared_model) - - # Export to EXIR Edge - exported = export(quantized_model, self.example_inputs, strict=True) - edge_program = executorch.exir.to_edge( - exported, - compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False), - ) - return edge_program - - def _apply_passes(self, edge_program): - """Apply both ReplaceQuantNodesPass and QuantizedOpFusionPass""" - passes = [QuantizedOpFusionPass(), ReplaceQuantNodesPass()] - final_program = edge_program.transform(passes) - return final_program - - def test_single_add_fusion(self): - """Single add with full Q/DQ pattern should fuse into one quantized_add node""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - edge_graph = edge_program.exported_program().graph_module - - # Get reference output - reference_output = edge_graph(*self.example_inputs) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify fusion occurred - check_count( - transformed_graph, - exir_ops.edge.cortex_m.quantized_add.default, - 1, # Should have exactly 1 fused quantized_add - ) - - # Verify the following - # Before fusion: - # x --> quantize_per_tensor --> dequantize_per_tensor --> add --> quantize_per_tensor --> - # dequantize_per_tensor --> output y --> quantize_per_tensor --> dequantize_per_tensor --^ - # After fusion: - # x --> quantize_per_tensor --> quantized_add --> dequantize_per_tensor --> output - # y --> quantize_per_tensor --^ - check_count( - transformed_graph, exir_ops.edge.cortex_m.quantize_per_tensor.default, 2 - ) - check_count( - transformed_graph, exir_ops.edge.cortex_m.dequantize_per_tensor.default, 1 - ) - check_count(transformed_graph, exir_ops.edge.cortex_m.quantized_add.default, 1) - - # Verify numerical equivalence - fused_output = transformed_graph(*self.example_inputs) - torch.testing.assert_close(reference_output, fused_output, rtol=1e-3, atol=1e-3) - - def test_multiple_add_fusion(self): - """Multiple independent adds should create multiple quantized_add nodes""" - - class MultipleAddModel(torch.nn.Module): - def forward(self, x, y): - z1 = x + y # First add - z2 = x + z1 # Second add - return z2 - - # Prepare model - edge_program = self._prepare_quantized_model(MultipleAddModel) - edge_graph = edge_program.exported_program().graph_module - - # Get reference output - reference_output = edge_graph(*self.example_inputs) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify multiple fusions occurred - check_count( - transformed_graph, - exir_ops.edge.cortex_m.quantized_add.default, - 2, # Should have 2 fused quantized_add nodes - ) - - # Verify numerical equivalence - fused_output = transformed_graph(*self.example_inputs) - torch.testing.assert_close(reference_output, fused_output, rtol=1e-3, atol=1e-3) - - def test_no_fusion_without_pattern(self): - """Add without proper Q/DQ pattern should not be fused""" - - class NonQuantizedAddModel(torch.nn.Module): - def forward(self, x, y): - # This will have add but not the full Q/DQ pattern after quantization - return torch.relu(x + y) # ReLU breaks the pattern - - # For this test, we'll create a model that doesn't have the complete pattern - # We need to manually construct a graph that has add without full Q/DQ - - model = NonQuantizedAddModel() - exported = export(model, self.example_inputs, strict=True) - edge_program = executorch.exir.to_edge( - exported, - compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False), - ) - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify no fusion occurred - check_count( - transformed_graph, - exir_ops.edge.cortex_m.quantized_add.default, - 0, # Should have no fused quantized_add nodes - ) - - def test_precomputed_parameters(self): - """Fused node should have precomputed multipliers/shifts instead of scales""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Get arguments of the fused quantized_add node - quantized_add_args = get_node_args( - transformed_graph, exir_ops.edge.cortex_m.quantized_add.default - ) - - # Should have exactly one quantized_add node - self.assertEqual(len(quantized_add_args), 1) - args = quantized_add_args[0] - - # Verify argument structure: (tensor1, zp1, mult1, shift1, tensor2, zp2, mult2, shift2, out_zp, out_mult, out_shift) - self.assertEqual(len(args), 11, "quantized_add should have 11 arguments") - - # Check that multipliers and shifts are integers (not floats/scales) - # args[2], args[3] = input1 multiplier, shift - # args[6], args[7] = input2 multiplier, shift - # args[9], args[10] = output multiplier, shift - for i in [2, 3, 6, 7, 9, 10]: # multiplier and shift positions - self.assertIsInstance( - args[i], int, f"Argument {i} should be an integer (precomputed)" - ) - - def test_mixed_fusion_pattern(self): - """Mixed pattern (some fusable, some not) should partially fuse""" - - class MixedModel(torch.nn.Module): - def forward(self, x, y): - z1 = x + y # This should fuse - z2 = torch.relu(z1) # ReLU breaks next fusion - z3 = z2 + x # This won't have full Q/DQ pattern - return z3 - - # Prepare model - edge_program = self._prepare_quantized_model(MixedModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Should have partial fusion (at least 1, but not necessarily all adds) - quantized_add_count = sum( - 1 - for node in transformed_graph.graph.nodes - if node.op == "call_function" - and node.target == exir_ops.edge.cortex_m.quantized_add.default - ) - - self.assertGreaterEqual( - quantized_add_count, 1, "Should have at least 1 fused operation" - ) - - def test_different_tensor_shapes(self): - """Different tensor shapes should still fuse correctly""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Test with different input shapes - for shape in [(2, 3), (10, 20, 30), (1,)]: - with self.subTest(shape=shape): - inputs = (torch.randn(shape), torch.randn(shape)) - - model = SingleAddModel() - exported_model = export(model.eval(), inputs, strict=True).module() - prepared_model = prepare_pt2e(exported_model, AddQuantizer()) - quantized_model = convert_pt2e(prepared_model) - - exported = export(quantized_model, inputs, strict=True) - edge_program = executorch.exir.to_edge( - exported, - compile_config=executorch.exir.EdgeCompileConfig( - _check_ir_validity=False - ), - ) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Verify fusion occurred regardless of shape - check_count( - transformed_graph, exir_ops.edge.cortex_m.quantized_add.default, 1 - ) - - def test_aot_parameter_computation_accuracy(self): - """Verify that AoT-computed parameters match runtime computation""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - transformed_graph = transformed_program.exported_program().graph_module - - # Get the fused node arguments - quantized_add_args = get_node_args( - transformed_graph, exir_ops.edge.cortex_m.quantized_add.default - )[0] - - # Extract the computed multipliers and shifts - input1_mult, input1_shift = quantized_add_args[2], quantized_add_args[3] - input2_mult, input2_shift = quantized_add_args[6], quantized_add_args[7] - output_mult, output_shift = quantized_add_args[9], quantized_add_args[10] - - # Verify they are reasonable values - # Multipliers should be in int32 range - self.assertTrue(-(2**31) <= input1_mult < 2**31) - self.assertTrue(-(2**31) <= input2_mult < 2**31) - self.assertTrue(-(2**31) <= output_mult < 2**31) - - # Shifts should be reasonable (typically -31 to 31) - self.assertTrue(-50 <= input1_shift <= 50) - self.assertTrue(-50 <= input2_shift <= 50) - self.assertTrue(-50 <= output_shift <= 50) - - # Output multiplier should be close to 2^30 (for 1.0 scale) - self.assertAlmostEqual(output_mult, 2**30, delta=1000) - self.assertEqual(output_shift, -18) - - def test_executorch_program_generation(self): - """Verify ExecuTorch program generation with fused ops""" - - class SingleAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # Prepare model - edge_program = self._prepare_quantized_model(SingleAddModel) - - # Apply passes - transformed_program = self._apply_passes(edge_program) - - # Generate ExecuTorch program - executorch_program = transformed_program.to_executorch() - - # Verify the program contains the expected fused operator - operator_names = [ - op.name - for op in executorch_program.executorch_program.execution_plan[0].operators - ] - - self.assertIn("cortex_m::quantized_add", operator_names) - self.assertIn("cortex_m::quantize_per_tensor", operator_names) - self.assertIn("cortex_m::dequantize_per_tensor", operator_names) - # quantize_per_tensor --> dequantize_per_tensor --> add --> quantize_per_tensor --> dequantize_per_tensor - # (input quant) (dequant) (fp32 add) (re-quant) (dequant) - # ↓ - # Fusion Pass detects pattern: - # dequantize_per_tensor --> quantized_add (Fused node) --> quantize_per_tensor - - def test_broadcastable_shapes(self): - """Verify that broadcastable shapes are supported""" - - class BroadcastAddModel(torch.nn.Module): - def forward(self, x, y): - return x + y - - # input broadcastable shapes - inputs = (torch.randn(4, 1), torch.randn(4, 8)) - print(inputs) - - # Prepare quantized model - edge_program = self._prepare_quantized_model(BroadcastAddModel) - - # Get unfused output - unfused_graph = edge_program.exported_program().graph_module - unfused_output = unfused_graph(*inputs) - if isinstance(unfused_output, tuple): - unfused_output = unfused_output[0] - - # Apply fusion pass - fused_program = self._apply_passes(edge_program) - fused_graph = fused_program.exported_program().graph_module - fused_output = fused_graph(*inputs) - if isinstance(fused_output, tuple): - fused_output = fused_output[0] - - # Check fusion occurred - check_count(fused_graph, exir_ops.edge.cortex_m.quantized_add.default, 1) - - # Compare fused vs unfused (both quantized) - torch.testing.assert_close(fused_output, unfused_output, rtol=1e-3, atol=1e-3) - - -if __name__ == "__main__": - unittest.main() diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py index 19de71444cd..70f91b3f1dc 100644 --- a/backends/cortex_m/test/tester.py +++ b/backends/cortex_m/test/tester.py @@ -22,6 +22,7 @@ ToEdge, ToExecutorch, ) + from executorch.exir import EdgeCompileConfig @@ -33,7 +34,13 @@ def __init__(self): class CortexMToEdge(ToEdge): def __init__(self): - config = EdgeCompileConfig(preserve_ops=[torch.ops.aten.linear.default]) + config = EdgeCompileConfig( + preserve_ops=[ + torch.ops.aten.linear.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + ] + ) super().__init__(config) diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index c95d34247be..ac97b9809bf 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -58,7 +58,9 @@ else() endif() # Link against ExecuTorch core libraries -target_link_libraries(cuda_tensor_maker PUBLIC executorch ${CMAKE_DL_LIBS}) +target_link_libraries( + cuda_tensor_maker PRIVATE executorch_core ${CMAKE_DL_LIBS} +) executorch_target_link_options_shared_lib(cuda_tensor_maker) install( @@ -67,50 +69,126 @@ install( DESTINATION lib ) -# CUDA-specific AOTI functionality -set(_aoti_cuda_sources - runtime/cuda_backend.cpp - runtime/shims/memory.cpp - runtime/shims/tensor_attribute.cpp - runtime/guard.cpp - runtime/shims/cuda_guard.cpp - runtime/shims/int4mm.cu - runtime/platform/platform.cpp +# Platform utilities (load_library, close_library, etc.) +set(_cuda_platform_sources runtime/platform/platform.cpp) +add_library(cuda_platform STATIC ${_cuda_platform_sources}) + +target_include_directories( + cuda_platform + PUBLIC $ $ + $ ) -add_library(aoti_cuda STATIC ${_aoti_cuda_sources}) + +target_compile_options( + cuda_platform + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> +) + +# Link against ExecuTorch core libraries +target_link_libraries(cuda_platform PRIVATE executorch_core ${CMAKE_DL_LIBS}) + +install( + TARGETS cuda_platform + EXPORT ExecuTorchTargets + DESTINATION lib +) + +# CUDA-specific AOTI shim symbols (dynamically linked) +set(_aoti_cuda_shim_sources + runtime/shims/memory.cpp runtime/shims/tensor_attribute.cpp + runtime/guard.cpp runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu + ${EXECUTORCH_ROOT}/backends/aoti/common_shims.cpp +) + +add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources}) + +# Define export macros for shared library +if(MSVC) + target_compile_definitions(aoti_cuda_shims PRIVATE EXPORT_AOTI_FUNCTIONS) + + # Ensure proper DLL import/export library naming on Windows + set_target_properties( + aoti_cuda_shims PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS OFF + ) +endif() + target_include_directories( - aoti_cuda + aoti_cuda_shims PUBLIC ${CUDAToolkit_INCLUDE_DIRS} $ $ ) + target_compile_options( - aoti_cuda PUBLIC $<$:/EHsc /GR> - $<$>:-fexceptions -frtti -fPIC> + aoti_cuda_shims + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> ) + # Ensure symbols are exported properly target_link_options( - aoti_cuda PUBLIC $<$>:-Wl,--export-dynamic> + aoti_cuda_shims PUBLIC $<$>:-Wl,--export-dynamic> ) -# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and PyTorch -# CUDA libraries +# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and +# platform utilities target_link_libraries( - aoti_cuda PUBLIC aoti_common cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS} + aoti_cuda_shims + PRIVATE cuda_platform + PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS} ) -# If you need other CUDA libraries, link them similarly: -# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...) -executorch_target_link_options_shared_lib(aoti_cuda) -if(BUILD_TESTING) - add_executable(multimodal_benchmark tests/multimodal_benchmark.cpp) - target_link_libraries( - multimodal_benchmark PUBLIC aoti_cuda extension_module_static - extension_flat_tensor portable_ops_lib - ) +if(NOT MSVC) + executorch_target_link_options_shared_lib(aoti_cuda_shims) +endif() + +install( + TARGETS aoti_cuda_shims + EXPORT ExecuTorchTargets + DESTINATION lib +) + +# CUDA backend implementation +set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) + +# CUDA backend implementation +add_library(aoti_cuda_backend STATIC ${_aoti_cuda_backend_sources}) + +target_include_directories( + aoti_cuda_backend + PUBLIC ${CUDAToolkit_INCLUDE_DIRS} $ + $ +) +target_compile_options( + aoti_cuda_backend + PUBLIC $<$:/EHsc /GR> + $<$>:-fexceptions -frtti -fPIC> +) +# Ensure symbols are exported properly +target_link_options( + aoti_cuda_backend PUBLIC + $<$>:-Wl,--export-dynamic> +) + +# Link against shims library and other dependencies On Windows (MSVC), use +# PRIVATE linkage for aoti_cuda_shims since the DLL is copied to the executable +# directory. On other platforms, use PUBLIC so the dependency propagates to +# consumers. +target_link_libraries( + aoti_cuda_backend PUBLIC cuda_platform extension_tensor cuda_tensor_maker + CUDA::cudart ${CMAKE_DL_LIBS} +) + +if(MSVC) + target_link_libraries(aoti_cuda_backend PRIVATE aoti_cuda_shims) +else() + target_link_libraries(aoti_cuda_backend PUBLIC aoti_cuda_shims) endif() +executorch_target_link_options_shared_lib(aoti_cuda_backend) + install( - TARGETS aoti_cuda + TARGETS aoti_cuda_backend EXPORT ExecuTorchTargets DESTINATION lib ) diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS index 94af87bbaed..3ae4eec6680 100644 --- a/backends/cuda/TARGETS +++ b/backends/cuda/TARGETS @@ -11,11 +11,13 @@ runtime.python_library( "//executorch/...", ], deps = [ + ":triton_replacement_pass", "//caffe2:torch", "//executorch/backends/aoti/passes:passes", "//executorch/exir/_serialize:lib", "//executorch/exir/backend:backend_details", "//executorch/exir/backend:compile_spec_schema", + "//executorch/backends/aoti:aoti_backend", ], ) @@ -32,3 +34,33 @@ runtime.python_library( "//executorch/backends/aoti:aoti_partitioner", ], ) + +runtime.python_library( + name = "triton_kernels", + srcs = [ + "triton/kernels/__init__.py", + "triton/kernels/sdpa.py", + ], + visibility = [ + "//executorch/backends/cuda/...", + ], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "triton_replacement_pass", + srcs = [ + "triton/__init__.py", + "triton/replacement_pass.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + ":triton_kernels", + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f8482835ea5..cc2d662b335 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -4,142 +4,63 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib -import os import typing -from enum import Enum - -from typing import Any, Dict, final, List, Optional, Set +from importlib import resources +from typing import Any, Dict, final, List import torch -from executorch.backends.aoti.passes.replace_view_copy_with_view import ( - ReplaceViewCopyWithViewPass, +from executorch.backends.aoti.aoti_backend import AotiBackend +from executorch.backends.cuda.triton.replacement_pass import ( + ReplaceEdgeOpWithTritonOpPass, ) -from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.decomposition import conv1d_to_conv2d -from torch.export.passes import move_to_device_pass -from torch.nn.attention import SDPBackend - -cuda_decomposition_table = { - torch.ops.aten.conv1d.default: conv1d_to_conv2d, -} - -# exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = { - "at::_ops::_weight_int4pack_mm::call": None, -} - -# required fallback kernels but not supported -missing_fallback_kernels: Set[str] = set() - - -class COMPILE_SPEC_KEYS(Enum): - METHOD_NAME = "method_name" - - -# context manager for non-fallback guarantee -# it will raise exception when generating fallback kernels during aoti compile -@contextlib.contextmanager -def collect_unsupported_fallback_kernels(): - original_generate_c_shim_extern_kernel_call = ( - CppWrapperCpu.generate_c_shim_extern_kernel_call - ) - original_generate_fallback_kernel_with_runtime_lookup_aot = ( - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot - ) - - def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - ): - if kernel not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel) - - original_generate_c_shim_extern_kernel_call( - self, kernel, args, device, debug_args=debug_args - ) - - def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( - self, - op_overload, - raw_args, - output_args, - raw_outputs, - ): - # Extract kernel name for collection - kernel_name = getattr(op_overload, "_name", str(op_overload)) - if kernel_name not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel_name) - - original_generate_fallback_kernel_with_runtime_lookup_aot( - self, op_overload, raw_args, output_args, raw_outputs - ) - - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels - ) - try: - yield - finally: - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - original_generate_c_shim_extern_kernel_call - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - original_generate_fallback_kernel_with_runtime_lookup_aot - ) @final @experimental( "This API and all of cuda backend related functionality are experimental." ) -class CudaBackend(BackendDetails): +class CudaBackend(AotiBackend, BackendDetails): """ CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators with libtorch-free. The compiled model can be executed on CUDA devices using the Executorch runtime. """ - @staticmethod - def preprocess( - edge_program: ExportedProgram, - compile_specs: List[CompileSpec], - ) -> PreprocessResult: - # Move the edge_program from CPU to CUDA for aoti compile - cuda_edge_program = move_to_device_pass(edge_program, "cuda") - - # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int - ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module) + @classmethod + def get_device_name(cls) -> str: + return "cuda" - cuda_edge_program = cuda_edge_program.run_decompositions( - cuda_decomposition_table - ) + @classmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + return { + "at::_ops::_weight_int4pack_mm::call": None, + } - edge_program_module = cuda_edge_program.module() + @classmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + return { + torch.ops.aten.conv1d.default: conv1d_to_conv2d, + } - # Grab all input placeholders from the graph - user_input_names = cuda_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in cuda_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) + @classmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass""" + return [ReplaceEdgeOpWithTritonOpPass()] - options: dict[str, typing.Any] = { + @classmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """ + Get AOTI compile options for CUDA backend. + Options may vary based on platform (Linux vs Windows). + """ + # Base options for all platforms + options: Dict[str, typing.Any] = { # Disable this to support sdpa decomposition # TODO(gasoonjia): remove it after pin bump to latest pytorch "loop_ordering_after_fusion": False, @@ -162,87 +83,34 @@ def preprocess( "max_autotune_conv_backends": "TRITON", } - with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel( - [ - SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. - ] - ), torch.no_grad(): - # torch._logging.set_logs(post_grad_graphs=True) - # Here we should expect 1 so file and 1 weight blob in the same directory. - paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Method {CudaBackend.method_name_from_compile_specs(compile_specs)} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." - ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" + # Parse compile_specs to check for platform + platform = "linux" + shim_library_path = None + for spec in compile_specs: + if spec.key == "platform": + platform = spec.value.decode("utf-8") + if spec.key == "shim_library_path": + shim_library_path = spec.value.decode("utf-8") + + # Add platform-specific options + if platform == "windows": + # For Windows, get default shim library path if not provided + if shim_library_path is None: + lib_dir = resources.files("executorch").joinpath("data/lib") + shim_library_path = str(lib_dir) + + options.update( + { + "aot_inductor.cross_target_platform": "windows", + "aot_inductor.aoti_shim_library": "aoti_cuda_shims", + "aot_inductor.aoti_shim_library_path": shim_library_path, + "aot_inductor.precompile_headers": False, + } ) + else: + # Linux platform + assert ( + shim_library_path is None + ), "shim_library_path should not be set for Linux" - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = CudaBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. - named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob" - ) - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Returns the compile spec representing the model compute precision, for additional details - please refer to the documentation for ``coremltools.precision``. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. - """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) + return options diff --git a/backends/cuda/runtime/aoti_cuda_shims.lib b/backends/cuda/runtime/aoti_cuda_shims.lib new file mode 100644 index 00000000000..bd6cc53bf07 Binary files /dev/null and b/backends/cuda/runtime/aoti_cuda_shims.lib differ diff --git a/backends/cuda/runtime/memory_tracker.h b/backends/cuda/runtime/memory_tracker.h new file mode 100644 index 00000000000..e09a96da6a6 --- /dev/null +++ b/backends/cuda/runtime/memory_tracker.h @@ -0,0 +1,192 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace executorch::backends::cuda { + +/** + * @class CudaMemoryTracker + * @brief Tracks CUDA memory usage and logs memory state at key points + * + * This class provides utilities to query and track CUDA memory usage, + * including peak memory usage and detailed memory state logging. + */ +class CudaMemoryTracker { + public: + /** + * @brief Constructor - initializes tracker and logs startup memory state + */ + CudaMemoryTracker() { + if (!query(&last_free_bytes_, &total_bytes_)) { + return; + } + available_ = true; + // Record the initial free bytes observed at startup. We'll use this as a + // baseline so reported "peak usage" reflects additional memory used + // since the tracker was created (instead of the absolute device usage, + // which may include other processes). + initial_free_bytes_ = last_free_bytes_; + min_free_bytes_ = last_free_bytes_; + log_state("startup", last_free_bytes_, total_bytes_); + } + + /** + * @brief Logs current memory state at a tagged checkpoint + * @param tag Descriptive tag for this memory sample (e.g., "after_load") + */ + void log_sample(const char* tag) { + if (!available_) { + return; + } + size_t free_bytes = 0; + size_t total_bytes = 0; + if (!query(&free_bytes, &total_bytes)) { + return; + } + min_free_bytes_ = std::min(min_free_bytes_, free_bytes); + total_bytes_ = total_bytes; + last_free_bytes_ = free_bytes; + log_state(tag, free_bytes, total_bytes); + } + + /** + * @brief Destructor - logs final memory state and peak usage summary + */ + ~CudaMemoryTracker() { + if (!available_) { + return; + } + size_t free_bytes = 0; + size_t total_bytes = 0; + if (!query(&free_bytes, &total_bytes)) { + return; + } + min_free_bytes_ = std::min(min_free_bytes_, free_bytes); + total_bytes_ = total_bytes; + last_free_bytes_ = free_bytes; + // Compute peak usage relative to the initial free baseline so that + // allocations by other processes present at startup are not attributed + // to this process. If for some reason initial_free_bytes_ was not set, + // fall back to absolute device usage. + double peak_mb = 0.0; + if (initial_free_bytes_ != std::numeric_limits::max()) { + size_t used_delta = 0; + if (initial_free_bytes_ > min_free_bytes_) { + used_delta = initial_free_bytes_ - min_free_bytes_; + } + peak_mb = static_cast(used_delta) / (1024.0 * 1024.0); + } else { + peak_mb = static_cast(total_bytes_ - min_free_bytes_) / + (1024.0 * 1024.0); + } + const double total_mb = + static_cast(total_bytes_) / (1024.0 * 1024.0); + ET_LOG( + Info, + "CUDA memory peak usage (since startup): %.2f MB, device total: %.2f MB", + peak_mb, + total_mb); + } + + private: + /** + * @brief Queries current CUDA memory info + * @param free_bytes Output parameter for free memory in bytes + * @param total_bytes Output parameter for total memory in bytes + * @return true if query succeeded, false otherwise + */ + bool query(size_t* free_bytes, size_t* total_bytes) { + cudaError_t err = cudaMemGetInfo(free_bytes, total_bytes); + if (err != cudaSuccess) { + if (!error_logged_) { + error_logged_ = true; + ET_LOG( + Error, + "cudaMemGetInfo failed with error: %s", + cudaGetErrorString(err)); + } + available_ = false; + return false; + } + return true; + } + + /** + * @brief Logs the current memory state + * @param tag Tag describing this log point + * @param free_bytes Current free memory in bytes + * @param total_bytes Current total memory in bytes + */ + void log_state(const char* tag, size_t free_bytes, size_t total_bytes) const { + const double used_mb = + static_cast(total_bytes - free_bytes) / (1024.0 * 1024.0); + const double free_mb = static_cast(free_bytes) / (1024.0 * 1024.0); + const double total_mb = + static_cast(total_bytes) / (1024.0 * 1024.0); + ET_LOG( + Info, + "CUDA memory (%s): used %.2f MB, free %.2f MB, total %.2f MB", + tag, + used_mb, + free_mb, + total_mb); + } + + bool available_{false}; + bool error_logged_{false}; + size_t last_free_bytes_{0}; + size_t total_bytes_{0}; + size_t min_free_bytes_{std::numeric_limits::max()}; + // Baseline free bytes observed at tracker construction. Used to compute + // peak usage attributable to this process since the tracker started. + size_t initial_free_bytes_{std::numeric_limits::max()}; + + public: + // Simple accessors to allow other components to read last-sampled values. + // These are safe to call after a successful log_sample() invocation. + uint64_t last_free_bytes() const { + return static_cast(last_free_bytes_); + } + uint64_t total_bytes() const { + return static_cast(total_bytes_); + } + uint64_t min_free_bytes() const { + return static_cast(min_free_bytes_); + } + uint64_t initial_free_bytes() const { + return static_cast(initial_free_bytes_); + } + double peak_usage_mb() const { + // Prefer peak relative to the initial free baseline; fall back to + // absolute device peak if baseline isn't available. + if (min_free_bytes_ == std::numeric_limits::max()) { + return 0.0; + } + if (initial_free_bytes_ != std::numeric_limits::max()) { + size_t used_delta = 0; + if (initial_free_bytes_ > min_free_bytes_) { + used_delta = initial_free_bytes_ - min_free_bytes_; + } + return static_cast(used_delta) / (1024.0 * 1024.0); + } + if (total_bytes_ == 0) { + return 0.0; + } + return static_cast(total_bytes_ - min_free_bytes_) / + (1024.0 * 1024.0); + } +}; + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/cuda_guard.h b/backends/cuda/runtime/shims/cuda_guard.h index f930f3df643..83fceabb98f 100644 --- a/backends/cuda/runtime/shims/cuda_guard.h +++ b/backends/cuda/runtime/shims/cuda_guard.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -33,9 +34,8 @@ using CUDAStreamGuardHandle = CUDAStreamGuard*; * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_create_cuda_guard( - int32_t device_index, - CUDAGuardHandle* ret_guard); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_create_cuda_guard(int32_t device_index, CUDAGuardHandle* ret_guard); /** * Deletes a CUDA device guard and frees its associated resources. @@ -44,7 +44,8 @@ AOTITorchError aoti_torch_create_cuda_guard( * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_delete_cuda_guard(CUDAGuardHandle guard); /** * Sets the CUDA device to a new index for an existing guard. @@ -54,9 +55,8 @@ AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard); * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_cuda_guard_set_index( - CUDAGuardHandle guard, - int32_t device_index); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_cuda_guard_set_index(CUDAGuardHandle guard, int32_t device_index); /** * Creates a CUDA stream guard that sets the current device and stream, @@ -69,7 +69,7 @@ AOTITorchError aoti_torch_cuda_guard_set_index( * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_create_cuda_stream_guard( +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_cuda_stream_guard( void* stream, int32_t device_index, CUDAStreamGuardHandle* ret_guard); @@ -81,7 +81,8 @@ AOTITorchError aoti_torch_create_cuda_stream_guard( * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); /** * Gets the current CUDA stream for a specified device. @@ -91,9 +92,8 @@ AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_get_current_cuda_stream( - int32_t device_index, - void** ret_stream); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); } // extern "C" diff --git a/backends/cuda/runtime/shims/int4mm.h b/backends/cuda/runtime/shims/int4mm.h index 6bd2d9b3a79..87a9916b0aa 100644 --- a/backends/cuda/runtime/shims/int4mm.h +++ b/backends/cuda/runtime/shims/int4mm.h @@ -10,6 +10,7 @@ #include #include +#include namespace executorch::backends::cuda { @@ -69,7 +70,7 @@ extern "C" { * or invalid qGroupSize * - Error::Internal: CUDA kernel launch failure */ -AOTITorchError aoti_torch_cuda__weight_int4pack_mm( +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm( Tensor* self, Tensor* mat2, int64_t qGroupSize, diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 46b8d448a3a..aaaf3913381 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -682,6 +682,95 @@ AOTITorchError aoti_torch__reinterpret_tensor( return Error::Ok; } +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle) { + // Validate input parameters + ET_CHECK_OR_RETURN_ERROR( + orig_handle != nullptr, + InvalidArgument, + "aoti_torch_new_tensor_handle failed: orig_handle is null"); + + ET_CHECK_OR_RETURN_ERROR( + new_handle != nullptr, + InvalidArgument, + "aoti_torch_new_tensor_handle failed: new_handle is null"); + + // Get metadata from the original tensor + int64_t* sizes_ptr; + int64_t* strides_ptr; + int32_t dtype; + int32_t device_type; + int32_t device_index; + + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_strides(orig_handle, &strides_ptr)); + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_device_type(orig_handle, &device_type)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_device_index(orig_handle, &device_index)); + + int64_t ndim = orig_handle->dim(); + + // Validate dtype + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Ensure device_index is always 0 + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + // Get the original data pointer from the source tensor + void* data_ptr = orig_handle->mutable_data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + // Check if the given memory is in the map + auto memory_it = memory_to_n_tensor.find(data_ptr); + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is not being tracked by reference counting system", + data_ptr); + + // Convert sizes and strides to vectors + std::vector sizes = convert_sizes_to_vector(ndim, sizes_ptr); + std::vector strides = + convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create new tensor that shares the same memory as the original + // This is similar to PyTorch's Tensor copy constructor - creates a new + // tensor object that shares the same underlying storage + std::shared_ptr tensor = make_tensor( + sizes, // Same sizes as original + data_ptr, // Share the same memory from source tensor + {}, // dim_order (empty, will be auto-generated) + strides, // Same strides as original + dtype_to_scalar_type(dtype) // Same dtype as original + ); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Failed to create new tensor handle"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *new_handle = tensor.get(); + + // Increment the reference count for this memory address only if it is owned + // by tensor + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + + return Error::Ok; +} } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h index 7a8d4c3609b..1a89d8b782c 100644 --- a/backends/cuda/runtime/shims/memory.h +++ b/backends/cuda/runtime/shims/memory.h @@ -10,6 +10,7 @@ #include #include +#include #include namespace executorch::backends::cuda { @@ -43,7 +44,7 @@ extern "C" { * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_create_tensor_from_blob_v2( +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( void* data, int64_t ndim, const int64_t* sizes_ptr, @@ -71,7 +72,7 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2( * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_empty_strided( +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_empty_strided( int64_t ndim, const int64_t* sizes_ptr, const int64_t* strides_ptr, @@ -87,7 +88,7 @@ AOTITorchError aoti_torch_empty_strided( * @return AOTITorchError error code (Error::Ok on success, or an error code on * failure) */ -AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); /** * Creates a tensor view that reinterprets the same underlying memory with @@ -106,7 +107,7 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor); * * @return Error::Ok on success, appropriate error code on failure */ -AOTITorchError aoti_torch__reinterpret_tensor( +AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor( Tensor* self, int64_t ndim, const int64_t* sizes_ptr, @@ -136,11 +137,36 @@ AOTITorchError aoti_torch__reinterpret_tensor( * - Error::MemoryAllocationFailed: failed to allocate temporary memory * - Error::Internal: CUDA operation failures */ -AOTITorchError +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking); +/** + * Creates a new tensor handle from an existing one. + * + * This function creates a new tensor object that shares the same underlying + * memory as the original tensor. Similar to PyTorch's Tensor copy constructor, + * it creates a new handle/reference to the same data without performing a deep + * copy. + * + * The new tensor will: + * - Share the same memory/storage as the original tensor + * - Have the same shape, strides, and dtype as the original + * - Increment the reference count for the underlying memory (if owned) + * + * @param orig_handle Original tensor to create a new handle from (must not be + * null) + * @param new_handle Output pointer to store the new tensor handle (must not be + * null) + * + * @return Error::Ok on success, appropriate error code on failure: + * - Error::InvalidArgument: null pointers or invalid parameters + */ +AOTITorchError aoti_torch_new_tensor_handle( + Tensor* orig_handle, + Tensor** new_handle); + // Function to clear all tensors from internal storage -void clear_all_tensors(); +AOTI_SHIM_EXPORT void clear_all_tensors(); } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tensor_attribute.h b/backends/cuda/runtime/shims/tensor_attribute.h index 6b61b5bd3b8..683f270ccda 100644 --- a/backends/cuda/runtime/shims/tensor_attribute.h +++ b/backends/cuda/runtime/shims/tensor_attribute.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -24,12 +25,11 @@ extern "C" { using AOTITorchError = Error; // Device type functions for tensor attributes -AOTITorchError aoti_torch_get_device_type( - Tensor* tensor, - int32_t* ret_device_type); +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type); // Device type constants -int32_t aoti_torch_device_type_cuda(); +AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cuda(); } // extern "C" diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index 34a9d60582f..b274ecf3675 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -34,3 +34,4 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_copy_") cuda_shim_cpp_unittest("aoti_torch_cuda_guard") cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") + cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp index 1b59fc1abdb..19fc4dad685 100644 --- a/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_cuda__weight_int4pack_mm.cpp @@ -264,15 +264,6 @@ TEST_F(AOTITorchInt4MMTest, NullInputHandling) { EXPECT_EQ(error, Error::InvalidArgument) << "Should fail with null output pointer"; } - - // Test null output tensor (ret0 points to null) - { - Tensor* null_output = nullptr; - AOTITorchError error = aoti_torch_cuda__weight_int4pack_mm( - A, B, qGroupSize, qScaleAndZeros, &null_output); - EXPECT_EQ(error, Error::InvalidArgument) - << "Should fail with null output tensor"; - } } // Test with larger batch size diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_new_tensor_handle.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_new_tensor_handle.cpp new file mode 100644 index 00000000000..d123443cbfa --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_new_tensor_handle.cpp @@ -0,0 +1,560 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_new_tensor_handle tests +class AOTITorchNewTensorHandleTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create test tensors + Tensor* create_test_tensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA), + int32_t device_index = 0) { + Tensor* tensor; + + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? tensor : nullptr; + } +}; + +// Test basic functionality of creating a new tensor handle +TEST_F(AOTITorchNewTensorHandleTest, BasicFunctionality) { + // Create an original tensor + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + // Create a new handle from the original tensor + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(new_tensor, nullptr); + + // Verify the new tensor has the same properties + EXPECT_EQ(new_tensor->dim(), orig_tensor->dim()); + EXPECT_EQ(new_tensor->size(0), orig_tensor->size(0)); + EXPECT_EQ(new_tensor->size(1), orig_tensor->size(1)); + EXPECT_EQ(new_tensor->numel(), orig_tensor->numel()); + + // Verify they share the same memory + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_tensor->mutable_data_ptr()); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating new handle from null tensor +TEST_F(AOTITorchNewTensorHandleTest, NullOriginalTensor) { + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(nullptr, &new_tensor); + + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test passing null pointer for new handle +TEST_F(AOTITorchNewTensorHandleTest, NullNewHandle) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, nullptr); + + EXPECT_EQ(error, Error::InvalidArgument); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +// Test memory sharing between original and new tensor handle +TEST_F(AOTITorchNewTensorHandleTest, MemorySharing) { + // Create an original tensor + std::vector sizes = {3, 4}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + // Get original memory pointer + void* orig_ptr = orig_tensor->mutable_data_ptr(); + ASSERT_NE(orig_ptr, nullptr); + + // Create a new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify both tensors point to the same memory + void* new_ptr = new_tensor->mutable_data_ptr(); + EXPECT_EQ(orig_ptr, new_ptr); + + // Clean up - deleting one should not affect the other's validity + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // New tensor should still be valid and accessible + void* still_valid_ptr = new_tensor->mutable_data_ptr(); + EXPECT_EQ(still_valid_ptr, new_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating multiple handles from the same tensor +TEST_F(AOTITorchNewTensorHandleTest, MultipleHandles) { + // Create an original tensor + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create multiple handles + std::vector handles; + const int num_handles = 5; + + for (int i = 0; i < num_handles; i++) { + Tensor* new_tensor; + AOTITorchError error = + aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_ptr); + handles.push_back(new_tensor); + } + + // Delete original tensor + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // All handles should still be valid + for (Tensor* handle : handles) { + EXPECT_EQ(handle->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle->dim(), 2); + EXPECT_EQ(handle->size(0), 2); + EXPECT_EQ(handle->size(1), 3); + } + + // Delete all handles + for (Tensor* handle : handles) { + EXPECT_EQ(aoti_torch_delete_tensor_object(handle), Error::Ok); + } +} + +// Test creating handle from tensor with custom strides +TEST_F(AOTITorchNewTensorHandleTest, CustomStrides) { + std::vector sizes = {3, 4}; + std::vector strides = {4, 1}; // Row-major strides + Tensor* orig_tensor = create_test_tensor(sizes, strides); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify strides are preserved + int64_t* orig_strides_ptr; + int64_t* new_strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(orig_tensor, &orig_strides_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + + EXPECT_EQ(orig_strides_ptr[0], new_strides_ptr[0]); + EXPECT_EQ(orig_strides_ptr[1], new_strides_ptr[1]); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle from bfloat16 tensor +TEST_F(AOTITorchNewTensorHandleTest, BFloat16Tensor) { + std::vector sizes = {2, 3, 4}; + Tensor* orig_tensor = create_test_tensor( + sizes, + {}, + static_cast(SupportedDTypes::BFLOAT16), + static_cast(SupportedDevices::CUDA)); + ASSERT_NE(orig_tensor, nullptr); + + // Verify original is bfloat16 + int32_t orig_dtype; + EXPECT_EQ(aoti_torch_get_dtype(orig_tensor, &orig_dtype), Error::Ok); + EXPECT_EQ(orig_dtype, static_cast(SupportedDTypes::BFLOAT16)); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify new tensor is also bfloat16 + int32_t new_dtype; + EXPECT_EQ(aoti_torch_get_dtype(new_tensor, &new_dtype), Error::Ok); + EXPECT_EQ(new_dtype, static_cast(SupportedDTypes::BFLOAT16)); + + // Verify element size (bfloat16 should be 2 bytes) + EXPECT_EQ(new_tensor->element_size(), 2); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle from scalar (0D) tensor +TEST_F(AOTITorchNewTensorHandleTest, ScalarTensor) { + std::vector sizes = {}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + EXPECT_EQ(orig_tensor->dim(), 0); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify scalar properties + EXPECT_EQ(new_tensor->dim(), 0); + EXPECT_EQ(new_tensor->numel(), 1); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle from zero-sized tensor +TEST_F(AOTITorchNewTensorHandleTest, ZeroSizedTensor) { + std::vector sizes = {0, 5}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + EXPECT_EQ(orig_tensor->numel(), 0); + + // Attempt to create new handle - should fail because zero-sized tensors have + // null data pointers + Tensor* new_tensor = nullptr; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + + // Zero-sized tensors are not currently supported + EXPECT_EQ(error, Error::InvalidArgument); + EXPECT_EQ(new_tensor, nullptr); + + // Clean up original tensor + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +// Test creating handle from large multi-dimensional tensor +TEST_F(AOTITorchNewTensorHandleTest, LargeMultiDimensionalTensor) { + std::vector sizes = {10, 20, 30}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify dimensions + EXPECT_EQ(new_tensor->dim(), 3); + EXPECT_EQ(new_tensor->size(0), 10); + EXPECT_EQ(new_tensor->size(1), 20); + EXPECT_EQ(new_tensor->size(2), 30); + EXPECT_EQ(new_tensor->numel(), 6000); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle preserves tensor metadata +TEST_F(AOTITorchNewTensorHandleTest, MetadataPreservation) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + Tensor* orig_tensor = create_test_tensor( + sizes, + strides, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CUDA)); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Get and compare all metadata + int64_t* orig_sizes_ptr; + int64_t* new_sizes_ptr; + int64_t* orig_strides_ptr; + int64_t* new_strides_ptr; + int32_t orig_dtype, new_dtype; + int32_t orig_device_type, new_device_type; + int32_t orig_device_index, new_device_index; + + EXPECT_EQ(aoti_torch_get_sizes(orig_tensor, &orig_sizes_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_sizes(new_tensor, &new_sizes_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(orig_tensor, &orig_strides_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + EXPECT_EQ(aoti_torch_get_dtype(orig_tensor, &orig_dtype), Error::Ok); + EXPECT_EQ(aoti_torch_get_dtype(new_tensor, &new_dtype), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_type(orig_tensor, &orig_device_type), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_type(new_tensor, &new_device_type), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_index(orig_tensor, &orig_device_index), Error::Ok); + EXPECT_EQ( + aoti_torch_get_device_index(new_tensor, &new_device_index), Error::Ok); + + // Verify all metadata matches + for (int i = 0; i < 3; i++) { + EXPECT_EQ(orig_sizes_ptr[i], new_sizes_ptr[i]); + EXPECT_EQ(orig_strides_ptr[i], new_strides_ptr[i]); + } + EXPECT_EQ(orig_dtype, new_dtype); + EXPECT_EQ(orig_device_type, new_device_type); + EXPECT_EQ(orig_device_index, new_device_index); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle chain: orig -> handle1 -> handle2 +TEST_F(AOTITorchNewTensorHandleTest, HandleChain) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create first handle + Tensor* handle1; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &handle1); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(handle1, nullptr); + EXPECT_EQ(handle1->mutable_data_ptr(), orig_ptr); + + // Create second handle from the first handle + Tensor* handle2; + error = aoti_torch_new_tensor_handle(handle1, &handle2); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(handle2, nullptr); + EXPECT_EQ(handle2->mutable_data_ptr(), orig_ptr); + + // Delete in reverse order + EXPECT_EQ(aoti_torch_delete_tensor_object(handle2), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(handle1), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); +} + +// Test creating handle and verifying reference counting +TEST_F(AOTITorchNewTensorHandleTest, ReferenceCountingTest) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create multiple handles + Tensor* handle1; + Tensor* handle2; + Tensor* handle3; + + EXPECT_EQ(aoti_torch_new_tensor_handle(orig_tensor, &handle1), Error::Ok); + EXPECT_EQ(aoti_torch_new_tensor_handle(orig_tensor, &handle2), Error::Ok); + EXPECT_EQ(aoti_torch_new_tensor_handle(orig_tensor, &handle3), Error::Ok); + + // Delete original + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // All handles should still be valid + EXPECT_EQ(handle1->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle2->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle3->mutable_data_ptr(), orig_ptr); + + // Delete handles one by one + EXPECT_EQ(aoti_torch_delete_tensor_object(handle1), Error::Ok); + + // Remaining handles should still be valid + EXPECT_EQ(handle2->mutable_data_ptr(), orig_ptr); + EXPECT_EQ(handle3->mutable_data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(handle2), Error::Ok); + + // Last handle should still be valid + EXPECT_EQ(handle3->mutable_data_ptr(), orig_ptr); + + EXPECT_EQ(aoti_torch_delete_tensor_object(handle3), Error::Ok); +} + +// Test creating handle from int32 tensor +TEST_F(AOTITorchNewTensorHandleTest, Int32Tensor) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor( + sizes, + {}, + 3, // int32 + static_cast(SupportedDevices::CUDA)); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify dtype + int32_t new_dtype; + EXPECT_EQ(aoti_torch_get_dtype(new_tensor, &new_dtype), Error::Ok); + EXPECT_EQ(new_dtype, 3); // int32 + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle with incontiguous tensor (transpose-like layout) +TEST_F(AOTITorchNewTensorHandleTest, IncontiguousTransposeLayout) { + std::vector sizes = {3, 4}; + std::vector strides = {1, 3}; // Column-major (incontiguous) + Tensor* orig_tensor = create_test_tensor(sizes, strides); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify strides are preserved + int64_t* new_strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + EXPECT_EQ(new_strides_ptr[0], 1); + EXPECT_EQ(new_strides_ptr[1], 3); + + // Verify both tensors share the same memory + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_tensor->mutable_data_ptr()); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Test creating handle with expanded strides (broadcasted dimension) +TEST_F(AOTITorchNewTensorHandleTest, ExpandedStrides) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {0, 4, 1}; // First dimension has stride 0 + Tensor* orig_tensor = create_test_tensor(sizes, strides); + ASSERT_NE(orig_tensor, nullptr); + + // Create new handle + Tensor* new_tensor; + AOTITorchError error = aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + + // Verify expanded strides are preserved + int64_t* new_strides_ptr; + EXPECT_EQ(aoti_torch_get_strides(new_tensor, &new_strides_ptr), Error::Ok); + EXPECT_EQ(new_strides_ptr[0], 0); + EXPECT_EQ(new_strides_ptr[1], 4); + EXPECT_EQ(new_strides_ptr[2], 1); + + // Clean up + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(new_tensor), Error::Ok); +} + +// Stress test: create many handles +TEST_F(AOTITorchNewTensorHandleTest, StressTestManyHandles) { + std::vector sizes = {2, 3}; + Tensor* orig_tensor = create_test_tensor(sizes); + ASSERT_NE(orig_tensor, nullptr); + + void* orig_ptr = orig_tensor->mutable_data_ptr(); + + // Create many handles + const int num_handles = 100; + std::vector handles; + + for (int i = 0; i < num_handles; i++) { + Tensor* new_tensor; + AOTITorchError error = + aoti_torch_new_tensor_handle(orig_tensor, &new_tensor); + EXPECT_EQ(error, Error::Ok); + ASSERT_NE(new_tensor, nullptr); + EXPECT_EQ(new_tensor->mutable_data_ptr(), orig_ptr); + handles.push_back(new_tensor); + } + + // Delete original + EXPECT_EQ(aoti_torch_delete_tensor_object(orig_tensor), Error::Ok); + + // All handles should still be valid + for (Tensor* handle : handles) { + EXPECT_EQ(handle->mutable_data_ptr(), orig_ptr); + } + + // Delete all handles + for (Tensor* handle : handles) { + EXPECT_EQ(aoti_torch_delete_tensor_object(handle), Error::Ok); + } +} diff --git a/backends/cuda/tests/TARGETS b/backends/cuda/tests/TARGETS index 12718c04388..974086cd4c5 100644 --- a/backends/cuda/tests/TARGETS +++ b/backends/cuda/tests/TARGETS @@ -19,6 +19,7 @@ python_unittest_remote_gpu( "//executorch/exir:lib", "//executorch/exir/backend:backend_api", "//executorch/exir/backend:compile_spec_schema", + "//executorch/examples/models/toy_model:toy_model", ], keep_gpu_sections = True, ) diff --git a/backends/cuda/tests/multimodal_benchmark.cpp b/backends/cuda/tests/multimodal_benchmark.cpp deleted file mode 100644 index 7365d0b7ba8..00000000000 --- a/backends/cuda/tests/multimodal_benchmark.cpp +++ /dev/null @@ -1,466 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace { - -using executorch::aten::ScalarType; -using executorch::aten::Tensor; -using executorch::extension::make_tensor_ptr; -using executorch::extension::TensorPtr; -using executorch::extension::module::Module; -using executorch::runtime::Error; -using executorch::runtime::EValue; -using executorch::runtime::Result; -using Clock = std::chrono::steady_clock; -using executorch::aten::TensorShapeDynamism; -using DurationMs = std::chrono::duration; - -enum class ModelType { GEMMA3, VOXTRAL, UNKNOWN }; - -struct ModelConfig { - std::string name; - size_t token_seq_len; - size_t text_embed_dim; - std::vector expected_methods; -}; - -const std::map model_configs = { - {ModelType::GEMMA3, - {"gemma3", - 128, - 2304, - {"vision_encoder", "token_embedding", "text_decoder"}}}, - {ModelType::VOXTRAL, - {"voxtral", - 1138, - 3072, - {"audio_encoder", "token_embedding", "text_decoder"}}}}; - -ModelType parse_model_type(const std::string& model_name) { - std::string lower_name = model_name; - std::transform( - lower_name.begin(), - lower_name.end(), - lower_name.begin(), - [](unsigned char c) { return std::tolower(c); }); - - if (lower_name.find("gemma3") != std::string::npos || - lower_name.find("gemma-3") != std::string::npos) { - return ModelType::GEMMA3; - } else if (lower_name.find("voxtral") != std::string::npos) { - return ModelType::VOXTRAL; - } - return ModelType::UNKNOWN; -} - -std::vector to_sizes( - std::initializer_list dims) { - return std::vector(dims.begin(), dims.end()); -} - -std::string format_shape(const Tensor& tensor) { - std::ostringstream oss; - oss << "["; - const auto& sizes = tensor.sizes(); - for (size_t i = 0; i < sizes.size(); ++i) { - if (i > 0) { - oss << ", "; - } - oss << sizes[i]; - } - oss << "]"; - return oss.str(); -} - -void print_tensor_summary(const std::string& label, const Tensor& tensor) { - std::cout << " " << label - << ": dtype=" << executorch::runtime::toString(tensor.scalar_type()) - << ", shape=" << format_shape(tensor) - << ", numel=" << tensor.numel() << std::endl; -} - -void dump_tensor_to_file(const std::string& filename, const Tensor& tensor) { - std::ofstream file(filename, std::ios::binary); - if (!file.is_open()) { - std::cerr << "Failed to open file for writing: " << filename << std::endl; - return; - } - - int32_t dtype = static_cast(tensor.scalar_type()); - file.write(reinterpret_cast(&dtype), sizeof(int32_t)); - - int32_t ndim = static_cast(tensor.sizes().size()); - file.write(reinterpret_cast(&ndim), sizeof(int32_t)); - - for (size_t i = 0; i < tensor.sizes().size(); ++i) { - int64_t dim_size = tensor.sizes()[i]; - file.write(reinterpret_cast(&dim_size), sizeof(int64_t)); - } - - const void* data_ptr = tensor.const_data_ptr(); - size_t element_size = 0; - - switch (tensor.scalar_type()) { - case ScalarType::Float: - element_size = sizeof(float); - break; - case ScalarType::BFloat16: - element_size = 2; - break; - case ScalarType::Half: - element_size = 2; - break; - case ScalarType::Long: - element_size = sizeof(int64_t); - break; - case ScalarType::Int: - element_size = sizeof(int32_t); - break; - default: - std::cerr << "Unsupported dtype for dumping: " - << executorch::runtime::toString(tensor.scalar_type()) - << std::endl; - return; - } - - size_t data_size = tensor.numel() * element_size; - file.write(reinterpret_cast(data_ptr), data_size); - file.close(); - - std::cout << "Dumped tensor to: " << filename << std::endl; -} - -TensorPtr create_vision_input() { - const auto sizes = to_sizes({1, 3, 896, 896}); - const size_t numel = 1ull * 3ull * 896ull * 896ull; - std::vector data(numel); - for (size_t i = 0; i < numel; ++i) { - data[i] = static_cast((i % 255) / 255.0); - } - return make_tensor_ptr( - sizes, - std::move(data), - {}, - {}, - ScalarType::BFloat16, - TensorShapeDynamism::DYNAMIC_UNBOUND); -} - -TensorPtr create_audio_input() { - const auto sizes = to_sizes({3, 128, 3000}); - const size_t numel = 3ull * 128ull * 3000ull; - std::vector data(numel, 0.5f); - return make_tensor_ptr( - sizes, std::move(data), {}, {}, ScalarType::BFloat16); -} - -TensorPtr create_token_ids_input(const ModelConfig& config) { - const auto sizes = to_sizes({1, static_cast(config.token_seq_len)}); - std::vector data(config.token_seq_len); - for (size_t i = 0; i < config.token_seq_len; ++i) { - data[i] = static_cast(i + 1); - } - return make_tensor_ptr(sizes, std::move(data)); -} - -TensorPtr create_positions_input(const ModelConfig& config) { - const auto sizes = to_sizes({static_cast(config.token_seq_len)}); - std::vector data(config.token_seq_len); - for (size_t i = 0; i < config.token_seq_len; ++i) { - data[i] = static_cast(i); - } - return make_tensor_ptr(sizes, std::move(data)); -} - -TensorPtr create_fallback_text_embedding(const ModelConfig& config) { - const auto sizes = to_sizes( - {1, - static_cast(config.token_seq_len), - static_cast(config.text_embed_dim)}); - const size_t numel = 1ull * config.token_seq_len * config.text_embed_dim; - std::vector data(numel, 0.0f); - return make_tensor_ptr( - sizes, std::move(data), {}, {}, ScalarType::BFloat16); -} - -struct MethodTiming { - double load_ms{0.0}; - double run_ms{0.0}; -}; - -enum class MethodCategory { ENCODER, TOKEN_EMBEDDING, TEXT_DECODER, UNKNOWN }; - -MethodCategory categorize_method(const std::string& method_name) { - std::string lower_name = method_name; - std::transform( - lower_name.begin(), - lower_name.end(), - lower_name.begin(), - [](unsigned char c) { return std::tolower(c); }); - - if (lower_name.find("vision") != std::string::npos || - lower_name.find("audio") != std::string::npos || - lower_name.find("encoder") != std::string::npos) { - return MethodCategory::ENCODER; - } else if ( - lower_name.find("token") != std::string::npos && - lower_name.find("embedding") != std::string::npos) { - return MethodCategory::TOKEN_EMBEDDING; - } else if ( - lower_name.find("text") != std::string::npos && - lower_name.find("decoder") != std::string::npos) { - return MethodCategory::TEXT_DECODER; - } - return MethodCategory::UNKNOWN; -} - -std::vector create_inputs_for_method( - const std::string& method_name, - MethodCategory category, - ModelType model_type, - const ModelConfig& config, - const EValue* token_output, - std::vector& owned_inputs) { - std::vector inputs; - - switch (category) { - case MethodCategory::ENCODER: { - if (method_name.find("vision") != std::string::npos) { - auto input = create_vision_input(); - owned_inputs.emplace_back(input); - inputs.emplace_back(*input); - } else if (method_name.find("audio") != std::string::npos) { - auto input = create_audio_input(); - owned_inputs.emplace_back(input); - inputs.emplace_back(*input); - } - break; - } - - case MethodCategory::TOKEN_EMBEDDING: { - auto token_ids = create_token_ids_input(config); - owned_inputs.emplace_back(token_ids); - inputs.emplace_back(*token_ids); - break; - } - - case MethodCategory::TEXT_DECODER: { - if (token_output && token_output->isTensor()) { - inputs.emplace_back(*token_output); - } else { - auto fallback_embedding = create_fallback_text_embedding(config); - owned_inputs.emplace_back(fallback_embedding); - inputs.emplace_back(*fallback_embedding); - } - - auto positions = create_positions_input(config); - owned_inputs.emplace_back(positions); - inputs.emplace_back(*positions); - break; - } - - default: - break; - } - - return inputs; -} - -Error execute_method( - Module& module, - const std::string& method_name, - MethodCategory category, - ModelType model_type, - const ModelConfig& config, - const EValue* token_output, - MethodTiming& timing, - EValue* output_storage = nullptr) { - ET_LOG(Info, "Loading %s...", method_name.c_str()); - - const auto load_start = Clock::now(); - const Error load_err = module.load_method(method_name); - const auto load_end = Clock::now(); - if (load_err != Error::Ok) { - std::cerr << "Failed to load method " << method_name << ": error code " - << static_cast(load_err) << std::endl; - return load_err; - } - timing.load_ms = DurationMs(load_end - load_start).count(); - - std::vector owned_inputs; - std::vector inputs = create_inputs_for_method( - method_name, category, model_type, config, token_output, owned_inputs); - - const auto run_start = Clock::now(); - ET_LOG(Info, "%s running", method_name.c_str()); - Result> output_result = - module.execute(method_name, inputs); - ET_LOG(Info, "%s done", method_name.c_str()); - const auto run_end = Clock::now(); - timing.run_ms = DurationMs(run_end - run_start).count(); - - if (output_result.error() != Error::Ok) { - std::cerr << method_name << " execution failed: error code " - << static_cast(output_result.error()) << std::endl; - return output_result.error(); - } - - const auto& outputs = output_result.get(); - if (!outputs.empty() && outputs[0].isTensor()) { - print_tensor_summary(method_name + " output", outputs[0].toTensor()); - - if (category == MethodCategory::ENCODER || - category == MethodCategory::TOKEN_EMBEDDING) { - dump_tensor_to_file(method_name + "_output.bin", outputs[0].toTensor()); - } - - if (output_storage) { - *output_storage = outputs[0]; - } - } - - return Error::Ok; -} - -} // namespace - -int main(int argc, char** argv) { - if (argc != 4) { - std::cerr - << "Usage: " << argv[0] - << " " - << std::endl; - std::cerr << " model_name: gemma3 or voxtral" << std::endl; - return 1; - } - - const std::string model_name = argv[1]; - const std::string program_path = argv[2]; - const std::string data_map_path = argv[3]; - - const ModelType model_type = parse_model_type(model_name); - if (model_type == ModelType::UNKNOWN) { - std::cerr << "Unknown model type: " << model_name << std::endl; - std::cerr << "Supported models: gemma3, voxtral" << std::endl; - return 1; - } - - const ModelConfig& config = model_configs.at(model_type); - std::cout << "Running benchmark for model: " << config.name << std::endl; - - try { - Module module(program_path, data_map_path); - - const auto program_load_start = Clock::now(); - const Error program_load_error = module.load(); - const auto program_load_end = Clock::now(); - if (program_load_error != Error::Ok) { - std::cerr << "Failed to load ExecuTorch program: error code " - << static_cast(program_load_error) << std::endl; - return 1; - } - const DurationMs program_load_latency = - program_load_end - program_load_start; - - auto method_names_result = module.method_names(); - if (method_names_result.error() != Error::Ok) { - std::cerr << "Failed to get method names: error code " - << static_cast(method_names_result.error()) << std::endl; - return 1; - } - - const auto& available_methods = method_names_result.get(); - - std::cout << "Checking for expected methods..." << std::endl; - std::vector missing_methods; - for (const auto& expected : config.expected_methods) { - if (available_methods.find(expected) == available_methods.end()) { - missing_methods.push_back(expected); - } else { - std::cout << " ✓ " << expected << std::endl; - } - } - - if (!missing_methods.empty()) { - std::cerr << "\nError: Missing expected methods:" << std::endl; - for (const auto& missing : missing_methods) { - std::cerr << " ✗ " << missing << std::endl; - } - return 1; - } - - std::map timings; - EValue token_output; - bool token_executed = false; - - for (const auto& method_name : config.expected_methods) { - MethodCategory category = categorize_method(method_name); - MethodTiming timing; - - const EValue* input_token_ptr = - (category == MethodCategory::TEXT_DECODER && token_executed) - ? &token_output - : nullptr; - - EValue* output_storage = (category == MethodCategory::TOKEN_EMBEDDING) - ? &token_output - : nullptr; - - Error err = execute_method( - module, - method_name, - category, - model_type, - config, - input_token_ptr, - timing, - output_storage); - - if (err != Error::Ok) { - return 1; - } - - if (category == MethodCategory::TOKEN_EMBEDDING) { - token_executed = true; - } - - timings[method_name] = timing; - } - - std::cout << std::fixed << std::setprecision(3); - std::cout << "\n=== Benchmark Results ===" << std::endl; - std::cout << "Program load latency (ms): " << program_load_latency.count() - << std::endl; - - std::cout << "\nMethod load latency (ms):" << std::endl; - for (const auto& [name, timing] : timings) { - std::cout << " " << name << ": " << timing.load_ms << std::endl; - } - - std::cout << "\nRun latency (ms):" << std::endl; - for (const auto& [name, timing] : timings) { - std::cout << " " << name << ": " << timing.run_ms << std::endl; - } - - return 0; - } catch (const std::exception& ex) { - std::cerr << "Unhandled exception: " << ex.what() << std::endl; - return 1; - } -} diff --git a/backends/cuda/tests/test_cuda_export.py b/backends/cuda/tests/test_cuda_export.py index ef43a3ab3cb..03f4e4a9602 100644 --- a/backends/cuda/tests/test_cuda_export.py +++ b/backends/cuda/tests/test_cuda_export.py @@ -10,6 +10,7 @@ import torch from executorch.backends.cuda.cuda_backend import CudaBackend from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.examples.models.toy_model import SdpaModule from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from torch.export import export @@ -270,3 +271,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Test export edge_program_manager = self._export_to_cuda_with_lower(module, inputs) self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed") + + def test_sdpa_single_kernel(self): + """ + Test CUDA export for model containing single SDPA kernel. + SDPA: Scaled Dot Product Attention + """ + + sdpa = SdpaModule() + + # Test export + edge_program_manager = self._export_to_cuda_with_lower( + sdpa.get_eager_model(), sdpa.get_example_inputs() + ) + self.assertIsNotNone( + edge_program_manager, + "SDPA single kernel operation export failed", + ) diff --git a/backends/cuda/triton/__init__.py b/backends/cuda/triton/__init__.py new file mode 100644 index 00000000000..4b9c36249ac --- /dev/null +++ b/backends/cuda/triton/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Import all kernels to ensure @triton_op decorators are executed +# and ops are registered to torch.ops.triton namespace +from executorch.backends.cuda.triton import kernels # noqa: F401 + +from executorch.backends.cuda.triton.replacement_pass import ( + ReplaceEdgeOpWithTritonOpPass, +) + +__all__ = [ + "ReplaceEdgeOpWithTritonOpPass", +] diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py new file mode 100644 index 00000000000..5bd582679c4 --- /dev/null +++ b/backends/cuda/triton/kernels/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.cuda.triton.kernels.sdpa import sdpa + +__all__ = [ + "sdpa", +] diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py new file mode 100644 index 00000000000..7e8eb1444df --- /dev/null +++ b/backends/cuda/triton/kernels/sdpa.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Triton SDPA Kernel for ExecuTorch CUDA Backend. + +This module provides a Triton-optimized implementation of scaled dot-product attention +that can replace the default ATen/Edge SDPA operator during graph transformation to allow +us export the model without decomposing the SDPA operator under libtorch free environment +and have better performance. +""" + +import math +from typing import Optional + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + + +def _next_power_of_2(n: int) -> int: + """Round up to the next power of 2.""" + if n <= 0: + return 1 + if n & (n - 1) == 0: + return n + + power = 1 + while power < n: + power <<= 1 + return power + + +def _validate_qkv_shapes( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, +) -> tuple[int, int, int, int, int, int]: + """ + Validate dimensions and return shape info. + Args: + query: Query tensor [B, H, L_q, D] + key: Key tensor [B, H, L_kv, D] + value: Value tensor [B, H, L_kv, D] + Returns: + Tuple of (B, H, L_q, L_kv, D_q, D_kv) + Raises: + RuntimeError: If dimensions are incompatible + """ + B_q, H_q, L_q, D_q = query.shape + B_k, H_k, L_kv_k, D_k = key.shape + B_v, H_v, L_kv_v, D_v = value.shape + # Validate batch and head dimensions + if not (B_q == B_k == B_v): + raise RuntimeError( + f"Batch dimension must match; got B_q={B_q}, B_k={B_k}, B_v={B_v}." + ) + + if not (H_q == H_k == H_v): + raise RuntimeError( + f"Head dimension must match; got H_q={H_q}, H_k={H_k}, H_v={H_v}." + ) + # Head dimension must match + if not (D_q == D_k == D_v): + raise RuntimeError( + f"Head dimension must match across Q, K, V; got D_q={D_q}, D_k={D_k}, D_v={D_v}." + ) + # Key and Value sequence lengths must match + if L_kv_k != L_kv_v: + raise RuntimeError( + f"Key and Value must have the same sequence length; got L_k={L_kv_k}, L_v={L_kv_v}." + ) + return B_q, H_q, L_q, L_kv_k, D_q, D_k + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=4, num_warps=8), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_stages=4, num_warps=8), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=1, num_warps=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=1, num_warps=2), + ], + key=["L_Q", "L_KV", "HEAD_DIM"], +) +@triton.jit +def _sdpa_fwd_kernel( + q_ptr, + k_ptr, + v_ptr, + mask_ptr, + o_ptr, + B, + H, + L_Q, # Query sequence length + L_KV, # Key/Value sequence length + HEAD_DIM, # Actual head dimension (may not be power of 2) + stride_qb, + stride_qh, + stride_ql, + stride_qd, + stride_kb, + stride_kh, + stride_kl, + stride_kd, + stride_vb, + stride_vh, + stride_vl, + stride_vd, + stride_mb, + stride_mh, + stride_ml, + stride_mn, + stride_ob, + stride_oh, + stride_ol, + stride_od, + sm_scale, + IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM_CE: tl.constexpr, # Rounded up for tl.arange +): + """ + Fused SDPA kernel that handles different sequence lengths for Q and K/V. + + Q shape: [B, H, L_Q, D] + K/V shape: [B, H, L_KV, D] + Output shape: [B, H, L_Q, D] + """ + # Program IDs + pid_m = tl.program_id(axis=0) # along query length + pid_hz = tl.program_id(axis=1) # flattened batch*head + off_b = pid_hz // H + off_h = pid_hz % H + # Compute ranges for queries + start_m = pid_m * BLOCK_M + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_DIM_CE) + mask_m = offs_m < L_Q # Mask based on query length + # Base pointers for this (b, h) + q_base = q_ptr + off_b * stride_qb + off_h * stride_qh + k_base = k_ptr + off_b * stride_kb + off_h * stride_kh + v_base = v_ptr + off_b * stride_vb + off_h * stride_vh + o_base = o_ptr + off_b * stride_ob + off_h * stride_oh + # Mask base pointer (if provided) + if HAS_MASK: + mask_base = mask_ptr + off_b * stride_mb + off_h * stride_mh + # Mask for actual head dimension (HEAD_DIM may not be power of 2) + mask_d = offs_d < HEAD_DIM + # Make head-dim addresses compiler-friendly + offs_d_ctg = tl.max_contiguous(tl.multiple_of(offs_d, 16), HEAD_DIM_CE) + # Load Q tile [BLOCK_M, HEAD_DIM] - coalesced along HEAD_DIM + q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d_ctg[None, :] * stride_qd) + q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + q = q.to(tl.bfloat16) + # Initialize accumulators and softmax stats + acc = tl.zeros((BLOCK_M, HEAD_DIM_CE), dtype=tl.float32) + m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) + l_i = tl.zeros((BLOCK_M,), dtype=tl.float32) + # Convert to base-2 scale for exp2 + qk_scale = sm_scale * 1.4426950408889634 + # Loop over keys/values along L_KV dimension (not L_Q!) + for start_n in tl.range(0, L_KV, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < L_KV # Mask based on key/value length + # Load K tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) + k_ptrs = k_base + ( + offs_n[:, None] * stride_kl + offs_d_ctg[None, :] * stride_kd + ) + k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + k = k.to(tl.bfloat16) + # Compute attention logits [BLOCK_M, BLOCK_N] = Q[BM,D] @ K[BN,D]^T + qk = tl.dot(q, tl.trans(k)).to(tl.float32) + qk = qk * qk_scale + # Apply causal mask if needed + # For causal masking with different lengths: position i can attend to position j if i >= j + if IS_CAUSAL: + causal_mask = offs_m[:, None] >= offs_n[None, :] + qk = tl.where(causal_mask, qk, -float("inf")) + # Apply attention mask if provided + if HAS_MASK: + # Load mask tile [BLOCK_M, BLOCK_N] + # Mask shape should be [B, H, L_Q, L_KV] + mask_ptrs = mask_base + ( + offs_m[:, None] * stride_ml + offs_n[None, :] * stride_mn + ) + attn_mask = tl.load( + mask_ptrs, + mask=mask_m[:, None] & mask_n[None, :], + other=0.0, + ) + # Convert boolean mask to additive mask (-inf for False, 0 for True) + qk = tl.where(attn_mask, qk, -float("inf")) + # Apply OOB masks for both rows and cols + qk = tl.where(mask_n[None, :], qk, -float("inf")) + qk = tl.where(mask_m[:, None], qk, -float("inf")) + # Online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + # Load V tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) + v_ptrs = v_base + ( + offs_n[:, None] * stride_vl + offs_d_ctg[None, :] * stride_vd + ) + v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) + v = v.to(tl.bfloat16) + # Update accumulator + acc = acc * alpha[:, None] + p_bf16 = p.to(tl.bfloat16) + acc = tl.dot(p_bf16, v, acc) + # Update softmax stats + l_i = l_i * alpha + l_ij + m_i = m_ij + # Normalize accumulator by softmax denominator + acc = acc / l_i[:, None] + # Store output [BLOCK_M, HEAD_DIM] - shape matches query + o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d_ctg[None, :] * stride_od) + tl.store(o_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_d[None, :]) + + +@triton_op("triton::sdpa", mutates_args={}) +def sdpa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, +) -> torch.Tensor: + """ + Triton fused Scaled Dot-Product Attention with support for different sequence lengths. + + Args: + query: Query tensor with szie [B, H, L_q, D] and dtype torch.bfloat16 + key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 + value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 + attn_mask: Optional attention mask [B, H, L_q, L_kv] or + broadcastable shape (2D: [L_q, L_kv] or 3D: [B, L_q, L_kv]) + dropout_p: must be 0.0 (others are not supported) + is_causal: whether to apply causal masking + scale: attention scale (default: 1/sqrt(D)) + enable_gqa: must be False (True is not supported) + Returns: + Output tensor [B, H, L_q, D] with dtype torch.bfloat16 + """ + # Validate inputs + if not (query.is_cuda and key.is_cuda and value.is_cuda): + raise RuntimeError("Q, K, V must be CUDA tensors.") + if ( + query.dtype != torch.bfloat16 + or key.dtype != torch.bfloat16 + or value.dtype != torch.bfloat16 + ): + raise RuntimeError("Expected bfloat16 inputs") + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise RuntimeError( + f"Expected 4D tensors shaped [B, H, L, D]; got " + f"query.dim()={query.dim()}, key.dim()={key.dim()}, " + f"value.dim()={value.dim()}." + ) + # Enforce unsupported features + if dropout_p != 0.0: + raise RuntimeError( + "dropout_p must be 0.0 (not supported in this implementation)." + ) + if enable_gqa is not False: + raise RuntimeError( + "enable_gqa must be False (not supported in this implementation)." + ) + # Validate and get dimensions + B, H, L_q, L_kv, D_q, D_kv = _validate_qkv_shapes(query, key, value) + D = D_q # Head dimension + # Allocate output with query shape + out = torch.empty_like(query) + # Element-wise strides + sqb, sqh, sql, sqd = query.stride() + skb, skh, skl, skd = key.stride() + svb, svh, svl, svd = value.stride() + sob, soh, sol, sod = out.stride() + + # Grid: tile queries (M) and batch*heads axis + def grid(META): + return ( + triton.cdiv(L_q, META["BLOCK_M"]), # Based on query length + B * H, + ) + + # Scale factor for SDPA + sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + # Handle attention mask + has_mask = attn_mask is not None + if has_mask: + # Expand mask to [B, H, L_q, L_kv] if needed + if attn_mask.dim() == 2: + # [L_q, L_kv] -> [B, H, L_q, L_kv] + attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1) + elif attn_mask.dim() == 3: + # [B, L_q, L_kv] -> [B, H, L_q, L_kv] + attn_mask = attn_mask.unsqueeze(1).expand(-1, H, -1, -1) + + # Validate mask shape + if attn_mask.shape != (B, H, L_q, L_kv): + # Try to expand if broadcastable + attn_mask = attn_mask.expand(B, H, L_q, L_kv) + + smb, smh, sml, smn = attn_mask.stride() + else: + # Dummy strides and mask + smb, smh, sml, smn = 0, 0, 0, 0 + attn_mask = torch.empty(0, dtype=torch.bool, device=query.device) + # Round up head dimension to next power of 2 for tile.arange in Triton kernel + HEAD_DIM_CE = _next_power_of_2(D) + # Launch kernel + wrap_triton(_sdpa_fwd_kernel)[grid]( + query, + key, + value, + attn_mask, + out, + B, + H, + L_q, # Query sequence length + L_kv, # Key/Value sequence length + D, # Actual head dimension + sqb, + sqh, + sql, + sqd, + skb, + skh, + skl, + skd, + svb, + svh, + svl, + svd, + smb, + smh, + sml, + smn, + sob, + soh, + sol, + sod, + sm_scale, + IS_CAUSAL=is_causal, + HAS_MASK=has_mask, + HEAD_DIM_CE=HEAD_DIM_CE, # Rounded to power of 2 + ) + return out + + +# Register the abstract/fake implementation for torch.export +# This is critical to avoid accessing real tensor data during export +@sdpa.register_fake +def _sdpa_abstract( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gq: bool = False, +) -> torch.Tensor: + """ + Abstract/fake implementation for torch.export. + This just returns an empty tensor with the correct shape/dtype/device. + """ + # Validate dtypes match + assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" + # Validate kqv's shape and get the output shape + B, H, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value) + + return torch.empty(B, H, L_q, D_q, dtype=query.dtype, device=query.device) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py new file mode 100644 index 00000000000..bfa3838296b --- /dev/null +++ b/backends/cuda/triton/replacement_pass.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Graph Transformation Pass for Triton Kernel Replacement. + +This pass replaces ATen operators with optimized Triton kernels in the graph. +""" + +import logging + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +logger = logging.getLogger(__name__) +triton = torch.ops.triton + +# Global mapping from edge dialect operators to Triton kernel functions +EDGE_TO_TRITON_KERNELS = { + exir_ops.edge.aten.scaled_dot_product_attention.default: triton.sdpa, +} + + +class ReplaceEdgeOpWithTritonOpPass(PassBase): + """ + Pass to replace ATen operators with Triton kernels. + + This pass scans the graph for Edge operators that have registered Triton + replacements using EDGE_TO_TRITON_KERNELS and replaces them with the + optimized Triton implementations. + """ + + def __init__(self): + """Initialize the pass.""" + super().__init__() + self._replacement_count = 0 + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Execute the pass on the graph module. + + Args: + graph_module: The graph module to transform + + Returns: + PassResult indicating success/failure and the modified graph module + """ + self._replacement_count = 0 + modified = False + + if not EDGE_TO_TRITON_KERNELS: + return PassResult(graph_module, False) + + # Iterate through all nodes in the graph + for node in graph_module.graph.nodes: + if self._should_replace_node(node): + try: + self._replace_node_with_triton(graph_module, node) + modified = True + self._replacement_count += 1 + except Exception as e: + logger.warning(f"Failed to replace node {node.name}: {e}") + # Continue with other replacements even if one fails + + if modified: + # Recompile the graph module after modifications + graph_module.recompile() + + # logger.info(f"Replaced {self._replacement_count} nodes with Triton kernels") + print(f"Replaced {self._replacement_count} nodes with Triton kernels") + + return PassResult(graph_module, modified) + + def _should_replace_node(self, node: Node) -> bool: + """ + Check if a node should be replaced with a Triton kernel. + + Args: + node: The node to check + + Returns: + True if the node should be replaced + """ + # Only consider call_function nodes + if node.op != "call_function": + return False + + return node.target in EDGE_TO_TRITON_KERNELS + + def _replace_node_with_triton(self, graph_module: GraphModule, node: Node) -> None: + """ + Replace an edge dialect node with a Triton kernel call. + + Args: + graph_module: The graph module containing the node + node: The node to replace + """ + # Get the target operator (should be an exir_ops edge dialect op) + target = node.target + + # Get the replacement kernel + if target not in EDGE_TO_TRITON_KERNELS: + raise ValueError(f"No replacement kernel found for {target}") + + triton_kernel_fn = EDGE_TO_TRITON_KERNELS[target] + + # Create a new node with the Triton kernel + with graph_module.graph.inserting_before(node): + # The triton_kernel_fn is already registered as a custom op via @triton_op + # We can call it directly + new_node = graph_module.graph.call_function( + triton_kernel_fn, + args=node.args, + kwargs=node.kwargs, + ) + + # Copy metadata from original node + new_node.meta = node.meta.copy() + + # Replace all uses of the old node with the new node + node.replace_all_uses_with(new_node) + + # Remove the old node + graph_module.graph.erase_node(node) diff --git a/backends/nxp/aten_passes/move_activation_before_concat.py b/backends/nxp/aten_passes/move_activation_before_concat.py new file mode 100644 index 00000000000..8ba306d42e2 --- /dev/null +++ b/backends/nxp/aten_passes/move_activation_before_concat.py @@ -0,0 +1,102 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class MoveActivationBeforeConcat(PassBase): + """Move some operators around in the following pattern. + This is a common pattern that emerges from the conversion of separable convolutions. + This optimization works together with joint quantization of compute nodes and activations. Without it, + it is not beneficial. + + │ │ │ │ + ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ + │ aten.conv2d │ ... │ aten.conv2d │ │ aten.conv2d │ ... │ aten.conv2d │ + └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ + └───────┐ ┌──────┘ │ │ + ┌──▼─────▼─┐ replace with ┌─────▼─────┐ ┌─────▼─────┐ + │ aten.cat │ ──────────────► │ aten.relu │ ... │ aten.relu │ + └────┬─────┘ └─────┬─────┘ └─────┬─────┘ + │ └───────┐ ┌───────┘ + ┌─────▼─────┐ ┌──▼─────▼─┐ + │ aten.relu │ │ aten.cat │ + └─────┬─────┘ └────┬─────┘ + │ │ + """ + + def __init__(self, neutron_target_spec: NeutronTargetSpec): + self.neutron_target_spec = neutron_target_spec + + def call(self, module: GraphModule) -> bool: + def _is_concat(node_: Node) -> bool: + return ( + node_.op == "call_function" + and node_.target == torch.ops.aten.cat.default + ) + + made_changes = False + + for node in module.graph.nodes: + if not _is_concat(node): + continue # Not cat node. + + cat_node = node + activation = next(iter(cat_node.users)) + + # Check if all cat inputs nodes are conv 2D or linear 2D type and their only user is cat. + if not all( + self.neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ): + continue + + # Check if following activation is supported on Neutron as fused activation. + if not ( + len(cat_node.users) == 1 + and self.neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + activation + ) + ): + continue + + # Loop all Cat input nodes and insert new activation after node. + for input_node in cat_node.all_input_nodes: + with module.graph.inserting_after(input_node): + new_activation = module.graph.call_function( + activation.target, + args=(), + kwargs=activation.kwargs, + ) + + new_activation.meta["source_fn_stack"] = [ + ( + new_activation.name, + activation.meta["source_fn_stack"][-1][-1], + ) + ] + new_activation.meta["val"] = input_node.meta["val"] + + # Replace the uses of the input node with the new activation node. + input_node.replace_all_uses_with(new_activation) + new_activation.args = (input_node, *activation.args[1:]) + + # Replace the uses of the activation node with the cat node. + activation.replace_all_uses_with(cat_node) + + module.graph.erase_node(activation) + + made_changes = True + + return PassResult(module, made_changes) diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py index 407ebf5da61..35205c76c68 100644 --- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py +++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py @@ -16,6 +16,9 @@ from executorch.backends.nxp.aten_passes.fuse_linear_and_add_pass import ( FuseLinearAndAddPass, ) +from executorch.backends.nxp.aten_passes.move_activation_before_concat import ( + MoveActivationBeforeConcat, +) from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import ( RemoveNodesWithKnownOutputs, ) @@ -25,6 +28,7 @@ from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import ( SplitGRUBasedOnNumLayers, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.exir.pass_manager import PassManager from torch import nn from torch.fx.passes.infra.pass_base import PassResult @@ -34,7 +38,9 @@ class NeutronAtenPassManager(PassManager): - def __init__(self, passes: list[PassType] = None): + def __init__( + self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None + ): passes: list[PassType] = passes or [ FuseBatchNormWithConvPass(), FuseBatchNormWithLinearPass(), @@ -42,6 +48,7 @@ def __init__(self, passes: list[PassType] = None): SplitGRUBasedOnNumLayers(), RemoveNodesWithKnownOutputs(), FuseLinearAndAddPass(), + MoveActivationBeforeConcat(neutron_target_spec), ] super().__init__(passes) diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py index b101fa7b056..5abb426636d 100644 --- a/backends/nxp/backend/edge_helper.py +++ b/backends/nxp/backend/edge_helper.py @@ -22,14 +22,6 @@ ] -def _is_dequantize(node_: Node) -> bool: - return node_.op == "call_function" and node_.target in DEQUANTIZE_OPERATORS - - -def _is_quantize(node_: Node) -> bool: - return node_.op == "call_function" and node_.target in QUANTIZE_OPERATORS - - def input_tensor(node: Node, input_index: int) -> torch.Tensor: if len(node.all_input_nodes) <= input_index: raise IndexError @@ -103,3 +95,33 @@ def try_get_tensor_constant_from_node( return None attr_itr = getattr(attr_itr, atom) return attr_itr + + +def _is_dequantize(node_: Node) -> bool: + return node_.op == "call_function" and node_.target in [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ] + + +def _is_quantize(node_: Node) -> bool: + return node_.op == "call_function" and node_.target in [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ] + + +def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None: + """Return the first node which is not a `quantize` or `dequantize`, found by traversing the graph backwards + starting with the `node.args[input_index]`, + """ + current_node = node.args[input_index] + while True: + if _is_quantize(current_node) or _is_dequantize(current_node): + current_node = current_node.args[0] + else: + return current_node diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index 6eef2017ec5..4189ac2dc47 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -19,7 +19,7 @@ from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec -from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops # noinspection PyProtectedMember @@ -63,7 +63,7 @@ def convert_program( conversion_config: ConversionConfig = _default_conversion_config, neutron_target_spec: NeutronTargetSpec = _default_target_spec, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, - ) -> (bytes, dict): + ) -> (bytes, dict[str, NodeFormat]): """ Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes. @@ -87,13 +87,16 @@ def convert_program( self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc) self._process_nodes(edge_program.graph.nodes, cc) - # Assign output - io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats( - edge_program.graph_signature - ) + # Assign the model its inputs and outputs. + cc.tflite_builder.assign_model_io_to_subgraph(edge_program.graph_signature) - # TFLite model generation + # Apply optimizations and finalize the model. internal_tflite_model = cc.tflite_builder.finish() + + # Extract the formats of the model's inputs and outputs. + io_formats = cc.tflite_builder.get_io_formats(edge_program.graph_signature) + + # TFLite model generation flatbuffers_builder = flatbuffers.Builder() internal_tflite_model.gen_tflite(flatbuffers_builder) diff --git a/backends/nxp/backend/ir/conversion_config.py b/backends/nxp/backend/ir/conversion_config.py index 622735e881f..4ba66adc942 100644 --- a/backends/nxp/backend/ir/conversion_config.py +++ b/backends/nxp/backend/ir/conversion_config.py @@ -13,7 +13,7 @@ def __init__(self, args: dict | None = None): :param args: Optional dictionary with conversion arguments. Unknown arguments are ignored. """ - self.keep_io_format: bool = False + self.use_neutron_for_format_conversion: bool = True self.allow_inputs_stripping: bool = True self.qdq_aware_conversion: bool = True self.symbolic_dimensions_mapping: dict[str, int] | None = None diff --git a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py index 51a4a226fc8..658b4fc93f7 100644 --- a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py +++ b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py @@ -88,19 +88,40 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]): self.check_and_append_operator(op) - def assign_model_io_to_subgraph_and_get_io_formats( - self, graph_signature - ) -> dict[str, dict]: - """ - Assign model's inputs/outputs to SubGraph. + def get_io_formats(self, graph_signature) -> dict[str, dict[str, TensorFormat]]: + """Get a mapping from tensor names to their formats. - :param graph_signature: Instance of GraphSignature. + :param graph_signature: Instance of GraphSignature. :returns: Mapping between IO tensors' names and their formats. """ io_formats = { "inputs": {}, "outputs": {}, } + for input_name in graph_signature.user_inputs: + tensor = self.tensor_for_name(input_name) + assert input_name == tensor.name, ( + "Program's input name doesn't match with tensor name in TFLite. " + "Input was probably redirected." + ) + io_formats["inputs"][tensor.name] = tensor.tensor_format + + for output_name in graph_signature.user_outputs: + tensor = self.tensor_for_name(output_name) + assert output_name == tensor.name, ( + "Program's output name doesn't match with tensor name in TFLite. " + "Output was probably redirected." + ) + io_formats["outputs"][tensor.name] = tensor.tensor_format + + return io_formats + + def assign_model_io_to_subgraph(self, graph_signature): + """ + Assign model's inputs/outputs to SubGraph. + + :param graph_signature: Instance of GraphSignature. + """ self.get_sub_graph().inputs = tflite_model.SubGraphInputs() for input_name in graph_signature.user_inputs: @@ -110,7 +131,6 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Input was probably redirected." ) self.get_sub_graph().inputs.tmp_inputs.append(tensor) - io_formats["inputs"][tensor.name] = tensor.tensor_format self.get_sub_graph().outputs = tflite_model.SubGraphOutputs() for output_name in graph_signature.user_outputs: @@ -120,7 +140,3 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Output was probably redirected." ) self.get_sub_graph().outputs.tmp_outputs.append(tensor) - - io_formats["outputs"][tensor.name] = tensor.tensor_format - - return io_formats diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py index 643a6231d15..cfd80d8e300 100755 --- a/backends/nxp/backend/ir/converter/builder/model_builder.py +++ b/backends/nxp/backend/ir/converter/builder/model_builder.py @@ -5,7 +5,9 @@ # License: MIT # See the LICENSE_MIT for more details. # + from copy import deepcopy +from itertools import chain from typing import Dict, List, Optional, Union import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator @@ -48,6 +50,9 @@ FlexTranspose, ) from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec @@ -218,7 +223,7 @@ def channels_first_version_of(self, t_tensor: tflite_model.Tensor): new_tensor.shape = translator.channels_last_shape_to_channels_first( t_tensor.shape ) - new_tensor.tensor_format = new_tensor.tensor_format.to_node_format() + new_tensor.tensor_format = TensorFormat.CHANNELS_FIRST perm = translator.create_channels_last_to_channels_first_permutation( t_tensor.rank @@ -355,6 +360,19 @@ def _make_inputs_channels_first(self): if input_tensor.tensor_format.is_channels_last(): # Create a Transpose operator and replace the graph input + new_input_shape = translator.channels_last_shape_to_channels_first( + input_tensor.shape + ) + perm = translator.create_channels_first_to_channels_last_permutation( + input_tensor.rank + ) + + if not transposition_is_supported_on_neutron( + new_input_shape.vector, list(perm), self.neutron_target_spec + ): + new_inputs.append(input_tensor) + continue + if input_tensor.rank > 6: msg = ( f"Couldn't preserve the shape of input tensor '{input_tensor.name}', because it has " @@ -365,14 +383,9 @@ def _make_inputs_channels_first(self): new_input = self.duplicate_tensor( input_tensor, input_tensor.name + "_channels_first" ) - new_input.shape = translator.channels_last_shape_to_channels_first( - input_tensor.shape - ) - new_input.tensor_format = input_tensor.tensor_format.to_node_format() + new_input.shape = new_input_shape + new_input.tensor_format = TensorFormat.CHANNELS_FIRST - perm = translator.create_channels_first_to_channels_last_permutation( - input_tensor.rank - ) transpose = self._create_transpose_operator( new_input, input_tensor, perm ) @@ -397,6 +410,16 @@ def _make_outputs_channels_first(self): if output_tensor.tensor_format.is_channels_last(): # Add a Transpose operator, to make the output channels first + shape = output_tensor.shape.vector + perm = translator.create_channels_last_to_channels_first_permutation( + len(shape), True + ) + if not transposition_is_supported_on_neutron( + shape, perm, self.neutron_target_spec + ): + new_outputs.append(output_tensor) + continue + if output_tensor.rank > 6: logger.e( logger.Code.IO_PRESERVATION_ERROR, @@ -437,6 +460,14 @@ def _keep_one_empty_buffer(self): # It's safe to replace the buffer. t.tmp_buffer = empty_buffer + def replace_io_tensor_format_with_node_format(self): + for t in chain( + self.get_sub_graph().inputs.tmp_inputs, + self.get_sub_graph().outputs.tmp_outputs, + ): + if isinstance(t.tensor_format, TensorFormat): + t.tensor_format = t.tensor_format.to_equal_node_format() + def finish(self) -> tflite_model.Model: """Finalize and optimize the converted TFLite model. Then return it. @@ -444,19 +475,23 @@ def finish(self) -> tflite_model.Model: :return: The final TFLite model. """ - if self.conversion_config.keep_io_format: + if self.conversion_config.use_neutron_for_format_conversion: # If the input or output is channels last, add a Transpose operator, to make is channels first. self._make_inputs_channels_first() self._make_outputs_channels_first() # Apply optimizations to the internal TFLite model. - optimizer.Optimizer(self, self.conversion_config).optimize( + optimizer.Optimizer( + self, self.conversion_config, self.neutron_target_spec + ).optimize( self.conversion_config.optimization_whitelist, self.conversion_config.optimization_blacklist, ) self._keep_one_empty_buffer() + self.replace_io_tensor_format_with_node_format() + # Remove outputs, which are not produced by any node. Otherwise, there would be errors after inference. operator_outputs = [] for op in self.get_operators().vector: diff --git a/backends/nxp/backend/ir/converter/node_converter.py b/backends/nxp/backend/ir/converter/node_converter.py index 36266486aac..b69861f85b0 100755 --- a/backends/nxp/backend/ir/converter/node_converter.py +++ b/backends/nxp/backend/ir/converter/node_converter.py @@ -185,6 +185,14 @@ def builder(self) -> AtenModelBuilderDirector: """ return self.context.tflite_builder + @property + def neutron_target_spec(self) -> NeutronTargetSpec: + """ + Get an instance of NeutronTargetSpec from the conversion context. + :return: NeutronTargetSpec instance. + """ + return self.builder.neutron_target_spec + def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator: """ Create TFLite op wrapper with input/output tensors added into 'tmp_inputs' and 'tmp_outputs'. diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index 9dea8ccd987..1990080e0f0 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -8,8 +8,10 @@ from executorch.backends.nxp.backend.custom_delegation_options import ( CustomDelegationOptions, ) +from executorch.backends.nxp.backend.edge_helper import previous_non_qdq_node from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + apply_permutation_to, create_channels_first_to_channels_last_permutation, ) from executorch.backends.nxp.backend.ir.converter.node_converter import ( @@ -23,6 +25,7 @@ from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter @@ -85,10 +88,6 @@ def _is_supported_on_target( dim = CatConverter._get_normalized_dim(node) - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1491 - if dim == 0: - return False - # Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the # last dimension, depending on the formats of the node. if node.meta[NXP_NODE_FORMAT].is_channels_first(): @@ -151,6 +150,46 @@ def _is_supported_in_IR( return True + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + ): + # There is a bug in the NeutronConverter, where if none of the input dimensions before the one referenced by + # `dim` are `!= 1`, the `Concat` is not delegated. + # This only happens when the inputs to the `Concat` are model inputs, and not outputs of other + # operators. + cat_partition = [p for p in partition_list if node in p.nodes][0] + cat_inputs = map(previous_non_qdq_node, node.args[0]) + + if not all( + input_.op == "call_function" and input_ in cat_partition.nodes + for input_ in cat_inputs + ): + # Some inputs of the `cat` are NOT in the same partition as `cat`. + dim = CatConverter._get_normalized_dim(node) + input_shapes = [list(n.meta["val"].shape) for n in node.args[0]] + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # Transform the shapes to channels last. + to_nhwc_perm = create_channels_first_to_channels_last_permutation( + len(node.meta["val"].shape), True + ) + input_shapes = [ + apply_permutation_to(shape, to_nhwc_perm) for shape in input_shapes + ] + + # Transform the `dim` to refer to a channels last dimension. + dim = to_nhwc_perm.index(dim) + + for input_shape in input_shapes: + if not any(d != 1 for d in input_shape[:dim]): + # Do not delegate if there are no "non-1" dimensions in the shape before the `dim` dimension. + return False + + return True + def convert(self, node: Node): """Convert the 'aten.cat' operator to TFLite 'Concatenation'.""" self.assert_convertible(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index f32b5a65cac..645274c7870 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -3,8 +3,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import logging - import numpy as np import torch @@ -32,17 +30,20 @@ from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import ( ConvConversionResult, ConvParameters, + get_node_tensor_params, ) from executorch.backends.nxp.backend.ir.converter.quantization_utils import ( set_quantization_parameters_to_tensor, ) from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( conv_2d_options, depthwise_conv_2d_options, reshape_options, + transpose_conv_options, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node @@ -57,18 +58,53 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - activations = node.args[0] + num_macs = neutron_target_spec.get_num_macs() + node_t_params = get_node_tensor_params(node) weights = node.args[1] - groups = node.args[8] + conv_params = ConvParameters( + *ConvolutionConverter._get_convolution_arguments(node) + ) - if activations.meta["val"].shape[0] != 1: + if node_t_params["batch_size"] != 1: # Only batch size 1 is supported on neutron. return False - if groups == 1: # Regular convolution. + if conv_params.transposed: + # TransposeConv1d is not supported on Neutron + if len(conv_params.dilation) == 1: + return False + if not node_is_effectively_static_tensor(weights, parameters_mapping): + # Only supported if the weights are static, because TFLite `TransposeConv` uses permuted + # weights. In case the weights are dynamic, a Transpose operator would have to be added, which + # is not supported on Neutron. + return False + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#876 TransposeConv2DKernelKind + if ( + conv_params.dilation != [1, 1] + or conv_params.padding[0] != 0 + or conv_params.padding[1] >= node_t_params["kernel_width"] + or ( + conv_params.padding[1] != 0 and node_t_params["inp_height"] != 1 + ) # Slice added by explicit padding + or conv_params.stride[0] != 1 + or ( + ( + conv_params.stride[1] != node_t_params["kernel_width"] / 2 + or node_t_params["out_height"] != 1 + ) + and conv_params.stride[1] != node_t_params["kernel_width"] + ) + or conv_params.stride[1] % 2 != 0 + or node_t_params["inp_channels"] % num_macs != 0 + or node_t_params["out_channels"] % num_macs != 0 + or node_t_params["kernel_width"] % 2 != 0 + or node_t_params["kernel_height"] != 1 + ): + return False + elif conv_params.groups == 1: # Regular convolution. pass elif conv_utils.group_conv_convertible_as_depthwise( - node, groups + node, conv_params.groups ): # Depthwise convolution. # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted # weights. In case the weights are dynamic, a Transpose operator would have to be added, which @@ -76,10 +112,9 @@ def _is_supported_on_target( if not node_is_effectively_static_tensor(weights, parameters_mapping): return False elif conv_utils.group_conv_convertible_into_multiple_convolutions( - node, groups - ): # Separable conv. This should never be reached, as the node should have been decomposed into - # multiple parallel convolutions by the `SplitGroupConvolution` pre-processing pass. - logging.warning("Group convolution was not decomposed.") + node, conv_params.groups + ): # Separable conv. + # Requires addition of `Split` and `Concatenation` operators, which are not supported on Neutron. return False else: # Unexpected case (should never happen). return False @@ -96,11 +131,15 @@ def _is_supported_in_IR( dimensions = input_tensor_rank - 2 is_transposed = node.args[6] output_padding = node.args[7] + groups = node.args[8] - if is_transposed: + if is_transposed and conv_utils.group_conv_convertible_as_depthwise( + node, groups + ): + # TFLite does not support transposed depthwise convolution return False - if output_padding != [0] * dimensions: + if not is_transposed and output_padding != [0] * dimensions: return False if input_tensor_safe(node, 2) is None: @@ -115,6 +154,20 @@ def _is_supported_in_IR( Transposed = bool Groups = int + def _compute_slicing_params( + self, output_shape, explicit_padding + ) -> tuple[list[int], list[int]]: + begins = [] + sizes = [] + + for axis in range(len(output_shape)): + (start, end) = explicit_padding[axis] + + begins.append(start) + sizes.append(output_shape[axis] - start - end) + + return begins, sizes + @staticmethod def _get_convolution_arguments( conv_node: Node, @@ -130,7 +183,7 @@ def _get_convolution_arguments( list(padding), list(dilation), transposed, - out_padding, + list(out_padding), groups, ) @@ -259,15 +312,16 @@ def _convert_unpadded_2D( [output_channels], "zero_bias", bias_type, False ) - # Compute scale and zero point for bias tensor - input_scale = np.array(x.quantization.scale.vector) - weight_scale = np.array(w.quantization.scale.vector) - bias_scale = input_scale * weight_scale - bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64) + if w.type in [TensorType.INT8, TensorType.UINT8]: + # Compute scale and zero point for bias tensor + input_scale = np.array(x.quantization.scale.vector) + weight_scale = np.array(w.quantization.scale.vector) + bias_scale = input_scale * weight_scale + bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64) - set_quantization_parameters_to_tensor( - b, bias_scale, bias_zero_point, quantized_dimension=0 - ) + set_quantization_parameters_to_tensor( + b, bias_scale, bias_zero_point, quantized_dimension=0 + ) # Assign the operator its TFLite inputs and outputs t_op.tmp_inputs = [x, w, b] @@ -278,87 +332,195 @@ def _convert_unpadded_2D( return conversion_result - def _convert_2d_conv( + def _convert_transpose_conv( self, t_op: tflite_model.Operator, conv_params: ConvParameters - ) -> list[tflite_model.Operator]: - if conv_utils.group_conv_convertible_as_depthwise( - t_op, conv_params.groups - ): # Convert to `DepthwiseConv2D`. - t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D() - - conversion_result = self._convert_unpadded_2D(t_op, conv_params) - t_op.builtin_options.padding, explicit_padding = ( - aten_translator.convert_padding(conv_params.padding) - ) - if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). - input_quantization = t_op.tmp_inputs[0].quantization - pad_value = ( - None - if input_quantization is None - else np.array(input_quantization.zero_point[0]).astype( - tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) - ) - ) - conversion_result.ops_list.add_pre( - self.builder.create_pad_operator_before( - t_op, 0, explicit_padding, constant_value=pad_value - ) - ) + ) -> conv_utils.ConvConversionResult: + """Convert the `aten.convolution` into TFLite TransposeConv. The `builtin_options` must be + converted by the caller. + """ + common.assign_2d_strides(t_op.builtin_options, conv_params.stride) - # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels] - perm = [3, 1, 2, 0] - weight_tensor = conversion_result.conv_weight_tensor - if tensor_has_data(weight_tensor): - # Transpose cloned tensor statically - t_op.tmp_inputs[1] = self.builder.create_transposed_tensor( - weight_tensor, perm - ) + x: tflite_model.Tensor = t_op.tmp_inputs[0] + w: tflite_model.Tensor = t_op.tmp_inputs[1] + y: tflite_model.Tensor = t_op.tmp_outputs[0] + + if (b := try_get_input(t_op, 2)) is None: + # Operator has no bias. Convolution aten op can omit it, TFLite can't. + # Weight tensor format in TFLite: [C, kH, kW, O] + # (C = input channels, O = output channels, kW = kernel width, kH = kernel height) + output_channels = w.shape.vector[-1] - if t_op.tmp_inputs[1].quantization is not None: - # Model is quantized - t_op.tmp_inputs[1].quantization.quantized_dimension = 3 + if w.type == TensorType.FLOAT32: + bias_type = np.dtype(np.float32) + elif w.type in [TensorType.INT8, TensorType.UINT8]: + bias_type = np.dtype(np.int32) else: - raise NotImplementedError("Dynamic Depthwise Conv weights.") + # Should never happen. + raise NotImplementedError( + f"Convolution node with unsupported weight type: {w.type}" + ) - elif conv_utils.group_conv_convertible_into_multiple_convolutions( - t_op, conv_params.groups - ): - # This case should have been rejected in the `is_supported_on_target()` method. - raise RuntimeError("Group convolution was not decomposed.") + b = self.builder.create_zeros_tensor( + [output_channels], "zero_bias", bias_type, True + ) + + if w.type in [TensorType.INT8, TensorType.UINT8]: + # Compute scale and zero point for bias tensor + input_scale = np.array(x.quantization.scale.vector) + weight_scale = np.array(w.quantization.scale.vector) + bias_scale = input_scale * weight_scale + bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64) + + set_quantization_parameters_to_tensor( + b, bias_scale, bias_zero_point, quantized_dimension=0 + ) + # TransposeConv weight tensor format in TFLite: [O, kH, kW, C] + # (C = input channels, O = output channels, kW = kernel width, kH = kernel height) + if tensor_has_data(w): + # Transpose cloned tensor statically + w = self.builder.create_transposed_tensor(w, [3, 1, 2, 0]) + + if w.quantization is not None: + # Model is quantized + w.quantization.quantized_dimension = 0 else: - # Convert to regular `Conv2D`. - t_op.builtin_options = conv_2d_options.Conv2D() - conversion_result = self._convert_unpadded_2D(t_op, conv_params) - t_op.builtin_options.padding, explicit_padding = ( - aten_translator.convert_padding(conv_params.padding) + raise NotImplementedError("Dynamic Transpose Conv weights.") + w.tensor_format = TensorFormat.TRANSPOSE_CONV_2D_WEIGHT_FORMAT + + output_shape_tensor_data = np.asarray(y.shape.vector, dtype=np.int32) + o = self.builder.create_tensor_for_data( + output_shape_tensor_data, "output_shape" + ) + + # Assign the operator its TFLite inputs and outputs + t_op.tmp_inputs = [o, w, x, b] + t_op.tmp_outputs = [y] + conversion_result = ConvConversionResult(x, w, b, y, o) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) + ) + if explicit_padding is not None: + # Add padding to output shape to make sure we have computed all the data we need + for idx, padding in enumerate(explicit_padding): + output_shape_tensor_data[idx] += padding[0] + padding[1] + y.shape = tflite_model.Shape(output_shape_tensor_data.tolist()) + + # We need to "cut" produced tensor by size of explicit padding + begins, sizes = self._compute_slicing_params( + output_shape_tensor_data.tolist(), explicit_padding ) - if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). - input_quantization = t_op.tmp_inputs[0].quantization - pad_value = ( - None - if input_quantization is None - else np.array(input_quantization.zero_point[0]).astype( - tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) - ) + slice_op = self.builder.create_slice_after(t_op, 0, begins, sizes) + conversion_result.ops_list.add_post(slice_op) + + conversion_result.ops_list.middle_op = t_op + + return conversion_result + + def _convert_2d_conv( + self, t_op: tflite_model.Operator, conv_params: ConvParameters + ) -> list[tflite_model.Operator]: + if conv_params.transposed: + t_op.builtin_options = transpose_conv_options.TransposeConv() + if conv_utils.group_conv_convertible_into_multiple_convolutions( + t_op, conv_params.groups + ): + # Convert to separated `TransposeConv`. + raise NotImplementedError("Separated TransposeConv not implemented.") + else: + # Convert to `TransposeConv`. + conversion_result = self._convert_transpose_conv(t_op, conv_params) + + else: + if conv_utils.group_conv_convertible_as_depthwise( + t_op, conv_params.groups + ): # Convert to `DepthwiseConv2D`. + t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D() + + conversion_result = self._convert_unpadded_2D(t_op, conv_params) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) ) - conversion_result.ops_list.add_pre( - self.builder.create_pad_operator_before( - t_op, 0, explicit_padding, constant_value=pad_value + if explicit_padding is not None: + # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). + input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) + conversion_result.ops_list.add_pre( + self.builder.create_pad_operator_before( + t_op, 0, explicit_padding, constant_value=pad_value + ) ) + + # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels] + perm = [3, 1, 2, 0] + weight_tensor = conversion_result.conv_weight_tensor + if tensor_has_data(weight_tensor): + # Transpose cloned tensor statically + t_op.tmp_inputs[1] = self.builder.create_transposed_tensor( + weight_tensor, perm + ) + + if t_op.tmp_inputs[1].quantization is not None: + # Model is quantized + t_op.tmp_inputs[1].quantization.quantized_dimension = 3 + else: + raise NotImplementedError("Dynamic Depthwise Conv weights.") + + elif conv_utils.group_conv_convertible_into_multiple_convolutions( + t_op, conv_params.groups + ): # Convert to separated `Conv2D`. + t_op.builtin_options = conv_2d_options.Conv2D() + + return conv_utils.create_separated_convolutions_based_on_group( + t_op, + conv_params, + self.builder, + self._convert_unpadded_2D, + conv_utils.conv_op_factory, + ) + + else: + # Convert to regular `Conv2D`. + t_op.builtin_options = conv_2d_options.Conv2D() + conversion_result = self._convert_unpadded_2D(t_op, conv_params) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) ) + if explicit_padding is not None: + # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). + input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) + conversion_result.ops_list.add_pre( + self.builder.create_pad_operator_before( + t_op, 0, explicit_padding, constant_value=pad_value + ) + ) return conversion_result.ops_list.flatten() def convert(self, node: Node): self.assert_convertible(node) - stride, padding, dilation, _, _, groups = self._get_convolution_arguments(node) + stride, padding, dilation, transposed, out_padding, groups = ( + self._get_convolution_arguments(node) + ) t_op = self._create_tflite_op_with_io_tensors(node) - conv_params = ConvParameters(stride, padding, dilation, groups) + conv_params = ConvParameters( + stride, padding, dilation, transposed, out_padding, groups + ) rank = t_op.tmp_inputs[1].shape.len() if rank == 3: # Conv1D diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py index c1dd7b600be..ac09e564eb8 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/mean_dim_converter.py @@ -1,5 +1,4 @@ -# Copyright (c) 2025 NXP -# All rights reserved. +# Copyright 2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -20,6 +19,7 @@ mean_options, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter @@ -32,15 +32,33 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - dim = node.args[1] keepdim = node.args[2] if len(node.args) >= 3 else False rank = len(node.args[0].meta["val"].shape) - dim = [d - rank if d > 0 else d for d in dim] + dim = [MeanDimConverter._to_pos_dim(d, rank) for d in node.args[1]] + + if rank != 4 or not keepdim: + # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#74-77 + return False - # Only last 2 dimensions (H, W) and keepdim=True with rank=4 are supported on Neutron. - if rank != 4 or dim not in [[-1, -2], [-2, -1]] or not keepdim: + # The `mean.dim` gets converted to AveragePool by the NeutronConverter, so the channels must be a + # multiple of `num_macs`. + # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#59-85 + num_macs = neutron_target_spec.get_num_macs() + channels_dim = 1 if node.meta[NXP_NODE_FORMAT].is_channels_first() else -1 + if (node.meta["val"].shape[channels_dim] % num_macs) != 0: return False + # Neutron only supports reduction over the spatial dimensions H, W. + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # The input is NCHW. H and W are at indices 2 and 3. + if dim not in [[2, 3], [3, 2]]: + return False + else: + # The input is formatless. It can be considered as NHWC, as this is the way Neutron will look at + # the dimensions. So H and W are the middle dimensions. + if dim not in [[1, 2], [2, 1]]: + return False + return True @staticmethod diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py index f0150b4bc1f..35bef6c8035 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py @@ -4,28 +4,438 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import torch +from executorch.backends.nxp.backend.edge_helper import ( + node_is_effectively_static_tensor, +) +from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext from executorch.backends.nxp.backend.ir.converter import quantization_utils +from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + NeutronTargetSpec, NodeConverter, ) +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( transpose_options, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + is_tensor_invariant_permutation, + transposition_is_supported_on_neutron, +) +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter +Permutation = list[int] +PermutationSupportDict = dict[str, dict[str, bool | Permutation]] + + +def _get_shape(node: torch.fx.Node) -> list[int]: + return list(node.meta["val"].shape) + + +def get_supported_transpositions( + node: Node, neutron_target_spec: NeutronTargetSpec +) -> PermutationSupportDict: + """Since ExecuTorch and NeutronIR use different tensor formats, we must consider the different possible cases + which may occur. The main permutation is always done on channels_first/formatless data, and the output is + channels_first/formatless as well. If this is not the case, a `Transpose` is inserted before and/or + after the main `Transpose`, to make the input/output channels_first. These additional `Transpose` + ops must be supported by Neutron as well. Alternatively, consecutive `Transpose` ops can be fused + together. It is possible for a pair of unsupported permutation to result in a supported one. + Therefore, the merged permutations must also be considered. + + This function identifies which of these permutations are supported on neutron, and returns a dictionary with the + support summary and the corresponding permutations. + + :param node: The `permute_copy` node to base the support analysis from/ + :param neutron_target_spec: NeutronTagetSpec instance. + :return: A dictionary containing the support status and permutation, for all the possible permutations which may be + used during the conversion of the `node`. + """ + + input_shape = node.args[0].meta["val"].shape + output_shape = node.meta["val"].shape + perm = list(node.args[1]) + + to_nchw_perm = translator.create_channels_last_to_channels_first_permutation( + len(input_shape), True + ) + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + main_perm_supported = transposition_is_supported_on_neutron( + input_shape, perm, neutron_target_spec + ) + + # "To NCHW" permutation, in case the input is channels last. + separate_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, to_nchw_perm, neutron_target_spec + ) + # The main permutation and the previous one merged. + merged_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, + merged_pre_transpose_permutation := translator.combine_permutations( + to_nchw_perm, perm + ), + neutron_target_spec, + ) + + # "To NHWC" permutation after the main `Transpose`. + separate_post_transpose_supported = transposition_is_supported_on_neutron( + output_shape, to_nhwc_perm, neutron_target_spec + ) + + # The main permutation and the previous one merged. + merged_post_transpose_supported = transposition_is_supported_on_neutron( + input_shape, + merged_post_transpose_permutation := translator.combine_permutations( + perm, to_nhwc_perm + ), + neutron_target_spec, + ) + + # "To NCHW", main permutation, and "to NHWC" all merged. + everything_merged_supported = transposition_is_supported_on_neutron( + input_shape, + everything_merged_permutation := translator.combine_permutations( + translator.combine_permutations(to_nchw_perm, perm), to_nhwc_perm + ), + neutron_target_spec, + ) + + return { + "main": {"supported": main_perm_supported, "perm": perm}, + "separate_pre": { + "supported": separate_pre_transpose_supported, + "perm": to_nchw_perm, + }, + "merged_pre": { + "supported": merged_pre_transpose_supported, + "perm": merged_pre_transpose_permutation, + }, + "separate_post": { + "supported": separate_post_transpose_supported, + "perm": to_nhwc_perm, + }, + "merged_post": { + "supported": merged_post_transpose_supported, + "perm": merged_post_transpose_permutation, + }, + "everything_merged": { + "supported": everything_merged_supported, + "perm": everything_merged_permutation, + }, + } + + +class PermuteCopyFormatHandler: + def __init__(self, context: ConversionContext): + self.context = context + + @property + def neutron_target_spec(self): + return self.context.tflite_builder.neutron_target_spec + + @property + def builder(self): + return self.context.tflite_builder + + def _handle_channels_first_input_and_formatless_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The input must be permuted. + # Either combine the permutations, or prepend a `Transpose` operator. + + if node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + # The input is static, so the operator will be removed by an optimization. + perm = perm_dict["main"]["perm"] + + elif perm_dict["merged_pre"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_pre"]["perm"] + + elif perm_dict["separate_pre"]["supported"] and perm_dict["main"]["supported"]: + # Prepend a `Transpose` operator to make the input channels first. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_channels_first_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The output must be permuted. + # Either combine the permutations, or append a `Transpose` operator. + + if node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + # The input is static, so the operator will be removed by an optimization. + perm = perm_dict["main"]["perm"] + + elif perm_dict["merged_post"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_post"]["perm"] + + elif perm_dict["main"]["supported"] and perm_dict["separate_post"]["supported"]: + # Append a `Transpose` operator to make the output channels first. + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_channels_first_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Both input and output must be permuted, or some merged permutations must be supported. + if perm_dict["everything_merged"]["supported"]: + # Combine all 3 permutations into 1. + perm = perm_dict["everything_merged"]["perm"] + + elif ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Combine the input and main permutations, and append a `Transpose` to handle the output permutation. + perm = perm_dict["merged_pre"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ): + # Prepend a `Transpose` to handle the input permutation, and combine the main and output permutations. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["everything_merged"]["supported"] + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Handle each permutation separately. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Neither the input nor the output have to be permuted. + if perm_dict["main"]["supported"]: + perm = perm_dict["main"]["perm"] + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + return perm + + def handle_tensor_formats(self, t_op: tflite_model.Operator, node: Node) -> OpsList: + """Due to the different tensor formats used by ExecuTorch and NeutronIR, it may be necessary to modify the + permutation, or insert extra permutations to equalize the tensor formats. + This method identifies the four possible cases of input/output formats, and finds the conversion solution + which minimizes the number of necessary `Transpose` operators. + """ + perm_dict = get_supported_transpositions(node, self.neutron_target_spec) + + ops = OpsList(middle_op=t_op) + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + perm = self._handle_channels_first_input_and_formatless_output( + perm_dict, node, t_op, ops + ) + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + perm = self._handle_formatless_input_and_channels_first_output( + perm_dict, node, t_op, ops + ) + + elif input_format.is_channels_first() and output_format.is_channels_first(): + perm = self._handle_channels_first_input_and_output( + perm_dict, node, t_op, ops + ) + + else: + perm = self._handle_formatless_input_and_output(perm_dict, node, t_op, ops) + + perm_tensor = self.builder.create_tensor_for_data( + np.array(perm, "int32"), "perm" + ) + + # Use the final permutation as the operator's second input. + t_op.tmp_inputs = [t_op.tmp_inputs[0], perm_tensor] + + return ops + class PermuteCopyConverter(NodeConverter): + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if node_is_effectively_static_tensor(node.args[0], parameters_mapping): + return ( + True # The operator computes on static data. It will be removed later. + ) + + input_shape = _get_shape(node.args[0]) + perm = list(node.args[1]) + + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + if is_tensor_invariant_permutation( + input_shape, perm + ) and is_tensor_invariant_permutation(channels_last_input_shape, perm): + # The `permute_copy` can always be represented as a Reshape. + return True + + perm_dict = get_supported_transpositions(node, neutron_target_spec) + + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + # Just the input must be permuted. + return ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_pre"]["supported"] + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + # Just the output must be permuted. + return ( + perm_dict["separate_post"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_post"]["supported"] + + elif input_format.is_channels_first() and output_format.is_channels_first(): + # Both input and output must be permuted. + return ( + # Separate IO transpositions. + ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Separate input, merged output. + or ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ) + # Merged input, separate output. + or ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Merged input and output. + or perm_dict["everything_merged"]["supported"] + ) + else: + # Simplest case. No format changes required. + return perm_dict["main"]["supported"] + @staticmethod def _is_supported_in_IR( node: Node, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: + if not NodeConverter._has_shared_q_params_if_quantized(node): + return False + return True def convert(self, node: Node): @@ -53,13 +463,6 @@ def convert(self, node: Node): "match. This indicates error in quantizer." ) - perm = np.array(node.args[1], "int32") - perm_tensor = self.builder.create_tensor_for_data(perm, "perm") - - # Assign the operator its TFLite inputs and outputs - t_op.tmp_inputs = [x, perm_tensor] - t_op.tmp_outputs = [y] - - ops_to_add = OpsList(middle_op=t_op) + ops = PermuteCopyFormatHandler(self.context).handle_tensor_formats(t_op, node) - self.builder.append_operators(ops_to_add.flatten()) + self.builder.append_operators(ops.flatten()) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py index 1d7c6b44627..3e20e504e8a 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py @@ -2,11 +2,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from abc import ABC, abstractmethod import numpy as np from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + create_channels_last_to_channels_first_permutation, torch_type_to_numpy_type, ) from executorch.backends.nxp.backend.ir.converter.node_converter import ( @@ -16,6 +18,8 @@ from executorch.backends.nxp.backend.ir.converter.quantization_utils import ( set_quantization_parameters_to_tensor, ) +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.tflite_generator.tflite_model import Tensor from torch.fx import Node from torch.nn import Parameter @@ -50,6 +54,9 @@ def convert(self, node: Node): scale = self.get_scale(node) zero_point = self.get_zero_point(node) + quantized_dimension = 0 + if isinstance(self, QDQPerChannelDequantizeConverter): + quantized_dimension = self.get_quantization_dimension(from_tensor, node) if self.context.parameters_mapping.get(node.args[0].name, None) is None: # Convert dequantize as identity op (Transpose that will be removed) because @@ -57,15 +64,21 @@ def convert(self, node: Node): # here we will change input name of the model. t_op = self._create_tflite_op_with_io_tensors(node) - set_quantization_parameters_to_tensor(to_tensor, scale, zero_point, 0) - set_quantization_parameters_to_tensor(from_tensor, scale, zero_point, 0) + set_quantization_parameters_to_tensor( + to_tensor, scale, zero_point, quantized_dimension + ) + set_quantization_parameters_to_tensor( + from_tensor, scale, zero_point, quantized_dimension + ) from_tensor.type = to_tensor.type self.builder.turn_operator_to_identity(t_op) self.builder.append_operators([t_op]) else: # Dequantize consumes tensor with static data -> convert as a tensor - set_quantization_parameters_to_tensor(to_tensor, scale, zero_point, 0) + set_quantization_parameters_to_tensor( + to_tensor, scale, zero_point, quantized_dimension + ) # Change type so we pass check tensor similarity check when redirecting from_tensor.type = to_tensor.type @@ -89,3 +102,15 @@ def get_zero_point(self, node: Node) -> np.ndarray: def get_scale(self, node: Node) -> np.ndarray: return self.context.parameters_mapping[node.args[1].name].numpy() + + def get_quantization_dimension(self, from_tensor: Tensor, node: Node) -> int: + quantization_dimension = node.args[3] + + # Quantization dimension is affected by tensor format + if from_tensor.tensor_format == TensorFormat.CHANNELS_LAST: + tensor_rank = len(from_tensor.shape.vector) + perm = create_channels_last_to_channels_first_permutation( + tensor_rank, return_list=True + ) + quantization_dimension = perm[quantization_dimension] + return quantization_dimension diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py index 5817fd127b3..2012ecc8640 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py @@ -16,6 +16,8 @@ class ConvParameters: stride: list[int] padding: list[int] dilation: list[int] + transposed: bool + out_padding: list[int] groups: int @@ -35,6 +37,29 @@ def _get_IO_channels(node: Node | tflite_model.Operator) -> (int, int): return input_channels, output_channels +def get_node_tensor_params(node: Node) -> dict: + node_tensor_params = {} + + input_tensor = node.args[0] + assert len(input_tensor.meta["val"].shape) in [3, 4], "Supports only Conv 1D, 2D." + node_tensor_params["batch_size"] = input_tensor.meta["val"].shape[0] + node_tensor_params["inp_channels"] = input_tensor.meta["val"].shape[1] + node_tensor_params["inp_height"] = input_tensor.meta["val"].shape[2] + if len(input_tensor.meta["val"].shape) == 4: + node_tensor_params["inp_width"] = input_tensor.meta["val"].shape[3] + + weights = node.args[1] + node_tensor_params["out_channels"] = node.meta["val"].shape[1] + node_tensor_params["out_height"] = node.meta["val"].shape[2] + if len(node.meta["val"].shape) == 4: + node_tensor_params["out_width"] = node.meta["val"].shape[3] + node_tensor_params["kernel_height"] = weights.meta["val"].shape[2] + if len(weights.meta["val"].shape) == 4: + node_tensor_params["kernel_width"] = weights.meta["val"].shape[3] + + return node_tensor_params + + def group_conv_convertible_as_depthwise(node: Node | tflite_model.Operator, group: int): input_channels, output_channels = _get_IO_channels(node) @@ -70,9 +95,11 @@ def __init__( weight_tensor: tflite_model.Tensor, bias_tensor: tflite_model.Tensor, output_tensor: tflite_model.Tensor, + output_shape_tensor: tflite_model.Tensor | None = None, ): self.conv_input_tensor = input_tensor self.conv_weight_tensor = weight_tensor self.conv_bias_tensor = bias_tensor self.conv_output_tensor = output_tensor + self.output_shape_tensor = output_shape_tensor self.ops_list = OpsList() diff --git a/backends/nxp/backend/ir/tensor_formatting.py b/backends/nxp/backend/ir/tensor_formatting.py index 32967ff047a..71b697a0eba 100644 --- a/backends/nxp/backend/ir/tensor_formatting.py +++ b/backends/nxp/backend/ir/tensor_formatting.py @@ -38,8 +38,10 @@ def is_channels_last(self) -> bool: @staticmethod def from_node_format(node_format: NodeFormat): - if node_format.is_channels_first(): - return TensorFormat.CHANNELS_LAST + if node_format == NodeFormat.CHANNELS_FIRST: + return TensorFormat.CHANNELS_LAST # Format is swapped. + elif node_format == NodeFormat.CHANNELS_LAST: + return TensorFormat.CHANNELS_FIRST # Format is swapped. elif node_format == NodeFormat.FORMATLESS: return TensorFormat.FORMATLESS else: @@ -47,8 +49,21 @@ def from_node_format(node_format: NodeFormat): def to_node_format(self): if self == TensorFormat.CHANNELS_LAST: - return NodeFormat.CHANNELS_FIRST + return NodeFormat.CHANNELS_FIRST # Format is swapped. elif self == TensorFormat.FORMATLESS: return NodeFormat.FORMATLESS + elif self == TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_LAST # Format is swapped. else: return NodeFormat.NONE + + def to_equal_node_format(self): + match self: + case TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_FIRST + case TensorFormat.CHANNELS_LAST: + return NodeFormat.CHANNELS_LAST + case TensorFormat.FORMATLESS: + return NodeFormat.FORMATLESS + case _: + return NodeFormat.NONE diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py index 6001ca961b8..18e397cc1bd 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py @@ -12,16 +12,21 @@ InputTensorToOpsMap, OutputTensorToOpMap, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class BaseOptimization(ABC): _builder: "model_builder.ModelBuilder" def __init__( - self, builder: "model_builder.ModelBuilder", conversion_config: ConversionConfig + self, + builder: "model_builder.ModelBuilder", + conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self._conversion_config = conversion_config + self.neutron_target_spec = neutron_target_spec def _create_tensor_to_operator_dictionaries( self, diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py deleted file mode 100755 index 4d10b7c80ae..00000000000 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/move_relu_before_concat.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2024 NXP -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from collections import defaultdict -from copy import deepcopy - -from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.ir.tflite_optimizer.operator_rules import ( - AllInputsComeFrom, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.base_optimization import ( - BaseOptimization, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.pattern_matcher import ( - Op, - PatternMatcher, -) -from executorch.backends.nxp.backend.ir.tflite_optimizer.tensor_rules import ( - TensorHasOneConsumer, - TensorsHaveSameQuantization, -) - - -class MoveActivationBeforeConcatenation(BaseOptimization): - """ - Move some operators around in the following pattern. - This is a common pattern that emerges from the conversion of separable convolutions. - - │ │ │ │ - ┌───▼────┐ ┌───▼────┐ ┌───▼────┐ ┌───▼────┐ - │ Conv2D │ ... │ Conv2D │ │ Conv2D │ ... │ Conv2D │ - └───┬────┘ └───┬────┘ └───┬────┘ └───┬────┘ - └──┐ ┌──┘ │ │ - ┌──▼──────────▼─┐ ┌──▼───┐ ┌──▼───┐ - │ Concatenation │ ─────► │ Relu │ ... │ Relu │ - └───────┬───────┘ └──┬───┘ └──┬───┘ - │ 'x' └──┐ ┌──┘ - ┌──▼───┐ ┌──▼──────────▼─┐ - │ Relu │ │ Concatenation │ - └──┬───┘ └───────┬───────┘ - │ 'y' │ - """ - - activations = ["Relu", "ReluN1To1", "Relu6", "Tanh", "Sign"] - - def __call__(self) -> bool: - matcher = PatternMatcher( - self._builder, - [ - Op(["Concatenation"], None, ["x"], [AllInputsComeFrom("Conv2D")]), - Op(self.activations, ["x"], ["y"]), - ], - [ - TensorHasOneConsumer("x"), - # If the activation function is not changing the quantization parameters, it can be moved without - # messing with the quantization elsewhere. - TensorsHaveSameQuantization(["x", "y"]), - ], - ) - - to_remove = [] - - # Mapping an operator to a list of operators. These operators (value) will later be added into the TFLite - # model's `operators` in front of the specified operator (key). - to_add: dict[tflite_model.Operator, list[tflite_model.Operator]] = defaultdict( - lambda: [] - ) - - for [concat, activation], _, _, _ in matcher.match_patterns(): - new_concat_inputs = [] - for concat_input in concat.tmp_inputs: - # Create a new operator for the activation function. - new_activation = deepcopy(activation) - new_activation.tmp_inputs = [concat_input] - new_activation_output = self._builder.duplicate_tensor(concat_input) - new_activation.tmp_outputs = [new_activation_output] - - to_add[concat].append( - new_activation - ) # Insert the new activation into the model later. - - new_concat_inputs.append( - new_activation_output - ) # Connect the activation with the `Concatenation`. - - concat.tmp_inputs = new_concat_inputs - - # Tensor rule ensures that only the activation functions is using the output of the `Concatenation`. - # It is safe to bypass. - concat.tmp_outputs[0] = activation.tmp_outputs[0] - to_remove.append(activation) - - operators = self._builder.get_operators() - - # Add the new activations into the model. - for concat, activations in to_add.items(): - idx = operators.index(concat) - for activation in activations: - operators.insert(idx, activation) - - # Remove the old activations. - for activation in to_remove: - operators.remove(activation) - - return len(to_remove) != 0 diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py index 0be46efcaa8..053e53d9df8 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py @@ -24,10 +24,14 @@ TensorIsNotModelOutput, TensorsHaveData, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) class FuseTransposeOperators(BaseOptimization): - """Remove some `Transpose` operators in the following pattern. + """Remove some `Transpose` operators in the following pattern. This is only done if the resulting permutation is + supported on Neutron. │ 'x' ┌─────▼─────┐ @@ -61,12 +65,27 @@ def __call__(self) -> bool: ) in matcher.match_patterns(): x = tensor_map["x"] perm1 = tensor_map["perm1"].tmp_buffer.data + combined_perms = [] # Remove the leading transpose. for second_transpose in following_transposes: # Combine the permutations for a new permutation of the second `Transpose`. perm2 = second_transpose.tmp_inputs[1].tmp_buffer.data - combined_perm = np.array(combine_permutations(perm1, perm2), np.int32) + combined_perms.append( + np.array(combine_permutations(perm1, perm2), np.int32) + ) + + if not all( + transposition_is_supported_on_neutron( + x.shape.vector, list(perm), self.neutron_target_spec + ) + for perm in combined_perms + ): + continue # Avoid creating an unsupported permutation. + + for second_transpose, combined_perm in zip( + following_transposes, combined_perms + ): second_transpose.tmp_inputs[1] = self._builder.create_tensor_for_data( combined_perm, "perm" ) diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py index 3611c55e995..1a96422e377 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py @@ -11,9 +11,6 @@ from executorch.backends.nxp.backend.ir import logger from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig -from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.move_relu_before_concat import ( - MoveActivationBeforeConcatenation, -) from executorch.backends.nxp.backend.ir.tflite_optimizer.optimizations.permute_fully_connected_weights_after_reshape import ( PermuteFullyConnectedWeightsAfterReshape, ) @@ -21,6 +18,7 @@ FuseTransposeOperators, RemoveIdentityTransposeOperators, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class Optimization(Enum): @@ -29,8 +27,6 @@ class Optimization(Enum): PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE = 12 - MOVE_ACTIVATION_BEFORE_CONCAT = 15 - class Optimizer: """ @@ -55,21 +51,19 @@ def __init__( self, builder: "model_builder.ModelBuilder", # noqa F821 conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self.optimization_map = { Optimization.FUSE_TRANSPOSE_OPERATORS: FuseTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.REMOVE_IDENTITY_TRANSPOSE_OPERATORS: RemoveIdentityTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE: PermuteFullyConnectedWeightsAfterReshape( - builder, conversion_config - ), - Optimization.MOVE_ACTIVATION_BEFORE_CONCAT: MoveActivationBeforeConcatenation( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), } diff --git a/backends/nxp/backend/neutron_converter_manager.py b/backends/nxp/backend/neutron_converter_manager.py index a6884a9ee24..a53a773f2ce 100644 --- a/backends/nxp/backend/neutron_converter_manager.py +++ b/backends/nxp/backend/neutron_converter_manager.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import importlib import logging import multiprocessing @@ -75,24 +76,37 @@ def convert(self, tflite_model: bytes, target: str) -> bytes: cctx = self.neutron_converter.CompilationContext() cctx.targetOpts = self.neutron_converter.getNeutronTarget(target) cctx.compilationOpts.minNumOpsPerGraph = 1 + cctx.compilationOpts.excludeGraphPasses = "MergeTranspose" + + # Try to use multiprocessing for isolation, but fall back to direct execution + # if the environment doesn't support it (e.g., in sandcastle/build environments) + try: + logger = multiprocessing.log_to_stderr() + logger.setLevel(logging.WARNING) + queue = multiprocessing.Manager().Queue() + + process = multiprocessing.Process( + target=convert_unsafe, + args=(self.neutron_converter, tflite_model, cctx, queue), + ) + process.start() + process.join() # waits until the subprocess is complete - logger = multiprocessing.log_to_stderr() - logger.setLevel(logging.WARNING) - queue = multiprocessing.Manager().Queue() - - process = multiprocessing.Process( - target=convert_unsafe, - args=(self.neutron_converter, tflite_model, cctx, queue), - ) - process.start() - process.join() # waits until the subprocess is complete + if queue.empty(): # signals the unsafe task did not run till the end + raise RuntimeError( + f"Neutron converter module terminated unexpectedly with exit code {process.exitcode}" + ) - if queue.empty(): # signals the unsafe task did not run till the end - raise RuntimeError( - f"Neutron converter module terminated unexpectedly with exit code {process.exitcode}" + model_converted = queue.get() + process.close() + except (EOFError, OSError) as e: + # Multiprocessing failed (likely due to environment restrictions) + # Fall back to direct execution + logging.warning( + f"Multiprocessing not available ({e}), running neutron converter directly" + ) + model_converted = self.neutron_converter.convertModel( + list(tflite_model), cctx ) - model_converted = queue.get() - - process.close() return bytes(model_converted) diff --git a/backends/nxp/backend/neutron_operator_support.py b/backends/nxp/backend/neutron_operator_support.py new file mode 100644 index 00000000000..cdb46870b2e --- /dev/null +++ b/backends/nxp/backend/neutron_operator_support.py @@ -0,0 +1,79 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + + +def is_tensor_invariant_permutation( + input_shape: list[int], permutation: list[int] +) -> bool: + def input_dim_is_not_one(index): + return input_shape[index] != 1 + + new_permutation = list(filter(input_dim_is_not_one, permutation)) + + return new_permutation == sorted(new_permutation) + + +def transposition_is_supported_on_neutron( + input_shape: list[int], + permutation: list[int], + neutron_target_spec: NeutronTargetSpec, +) -> bool: + """This function determines if the current NeutronSoftware properly supports a `Transpose` operator with given + `input_shape` and `permutation`. + + :param input_shape: The shape of the main input tensor of the `Transpose` operator. + :param permutation: The permutation the `Transpose` operator is computing. + :param neutron_target_spec: Object for querying the target platform to retrieve its properties. + """ + num_macs = neutron_target_spec.get_num_macs() + + if is_tensor_invariant_permutation(input_shape, permutation): + # The `Transpose` will be turned into a `Reshape` by Neutron. The check includes the identity permutation. + return True + + if permutation == [0, 3, 1, 2]: + # NHWC -> NCHW + n, h, w, c = input_shape + + if h * w * c % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 3, 1, 2] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + elif permutation == [0, 2, 3, 1]: + # NCHW -> NHWC + + n, c, h, w = input_shape + + if w % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 2, 3, 1] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + return False diff --git a/backends/nxp/backend/node_format.py b/backends/nxp/backend/node_format.py index 91049c200d7..fd54e2365ed 100644 --- a/backends/nxp/backend/node_format.py +++ b/backends/nxp/backend/node_format.py @@ -19,5 +19,8 @@ class NodeFormat(Enum): # Format has not been identified NONE = 2 + # NHWC + CHANNELS_LAST = 3 + def is_channels_first(self) -> bool: return self == NodeFormat.CHANNELS_FIRST diff --git a/backends/nxp/backend/node_format_inference.py b/backends/nxp/backend/node_format_inference.py index 78f8dff8c32..244fd76d588 100644 --- a/backends/nxp/backend/node_format_inference.py +++ b/backends/nxp/backend/node_format_inference.py @@ -30,7 +30,10 @@ class NodeFormatInference: # A set of Edge Aten ops, which have the ability to change the format (for example - input nodes # are channels first but output is formatless). - ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default} + ops_that_can_change_tensor_format = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + } _type_changed_during_last_run: bool @@ -88,11 +91,23 @@ def _infer_format_of_nodes(self, node: Node): if op_type in self.ops_with_channels_first_nodes: self._handle_node_which_uses_channels_first_format(node) + elif op_type in self.ops_that_can_change_tensor_format: - if op_type == exir_ops.edge.aten.view_copy.default: # view_copy + if op_type in [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + ]: + # Try to assign the `formatless` format to the input and output. The converter will then handle the + # transition. + # Note: If the format for the input/output has already been assigned as channels first, it will NOT be + # overwritten. self._assign_format_to_node( self._node_outputs[node][0], NodeFormat.FORMATLESS ) + self._assign_format_to_node( + self._node_inputs[node][0], NodeFormat.FORMATLESS + ) + else: logger.error( f"Node format inference for node type: {op_type} not found!" diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index 6be4495d615..f89bac55bc5 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -208,12 +208,13 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]): exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter, # noqa F405 exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405 exir_ops.edge.aten.mm.default: MMConverter, # noqa F405 + exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405 exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405 exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405 + exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405 exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405 exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405 - exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 } diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index b133a588c03..457fa335ba6 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -14,16 +14,17 @@ import numpy as np import torch -from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass +from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NodeFormat from executorch.backends.nxp.neutron_node_extraction import ( extract_artifacts_from_neutron_node, NeutronNodeArtifacts, @@ -44,6 +45,7 @@ def __init__(self): self.output_format = None self.operators_not_to_delegate: List[str] = [] self.neutron_converter_flavor = None + self.use_neutron_for_format_conversion = True def _replace_colons(self, operator: str) -> str: """ @@ -57,6 +59,7 @@ def neutron_compile_spec( neutron_converter_flavor: str, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ): """ Generate compile spec for Neutron NPU @@ -67,6 +70,9 @@ def neutron_compile_spec( "'neutron_converter_SDK_25_09' has flavor 'SDK_25_09'. extra_flags: Extra flags for the Neutron compiler operators_not_to_delegate: List of operators that should not be delegated + use_neutron_for_format_conversion: If True, the EdgeProgramToIRConverter will insert `Transpose` ops to + ensure that the IO matches the executorch partition, which will be + delegated to Neutron. """ self.neutron_converter_flavor = neutron_converter_flavor @@ -86,6 +92,8 @@ def neutron_compile_spec( self._replace_colons(op) for op in operators_not_to_delegate ] + self.use_neutron_for_format_conversion = use_neutron_for_format_conversion + return self def build(self): @@ -104,6 +112,10 @@ def build(self): "operators_not_to_delegate", ",".join(self.operators_not_to_delegate).encode(), ), + CompileSpec( + "use_neutron_for_format_conversion", + f"{self.use_neutron_for_format_conversion}".encode(), + ), ] return self.compile_spec @@ -115,6 +127,7 @@ def generate_neutron_compile_spec( system_config: Optional[str] = None, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ) -> List[CompileSpec]: return ( NeutronCompileSpecBuilder() @@ -123,6 +136,7 @@ def generate_neutron_compile_spec( neutron_converter_flavor, extra_flags=extra_flags, operators_not_to_delegate=operators_not_to_delegate, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, ) .build() ) @@ -145,6 +159,7 @@ def preprocess( # noqa C901 binary = bytes() target = "" neutron_converter_flavor = "" + use_neutron_for_format_conversion = None for spec in compile_spec: if spec.key == "output_format": output_format = spec.value.decode() @@ -154,6 +169,8 @@ def preprocess( # noqa C901 compile_flags.append(spec.value.decode()) if spec.key == "neutron_converter_flavor": neutron_converter_flavor = spec.value.decode() + if spec.key == "use_neutron_for_format_conversion": + use_neutron_for_format_conversion = spec.value.decode() == "True" # Check that the output format is set in the compile spec if not output_format: @@ -180,9 +197,15 @@ def preprocess( # noqa C901 ).transform() # Convert the edge program to TFLite. + conversion_config = ConversionConfig( + {"use_neutron_for_format_conversion": use_neutron_for_format_conversion} + if use_neutron_for_format_conversion is not None + else {} + ) tflite_model, io_formats = EdgeProgramToIRConverter().convert_program( edge_program, neutron_target_spec=NeutronTargetSpec(target, neutron_converter_flavor), + conversion_config=conversion_config, ) neutron_model = NeutronConverterManager(neutron_converter_flavor).convert( @@ -241,7 +264,9 @@ def _format_string_for_array(self, array: np.ndarray) -> str: return f"{array.size}s{self._padding_format_string_for_array(array)}" - def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: + def _create_payload_header( + self, io_formats: dict[str, list[NodeFormat]], neutron_artifacts + ) -> np.ndarray: """ Create bytes header for returned payload. It contains information about input and output tensor formats. Tensors are ordered based on graph signature @@ -279,9 +304,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: for input_name in neutron_artifacts.input_names: try: header_data.append( - 1 - if inputs[input_name.decode()] == TensorFormat.CHANNELS_LAST - else 0 + 1 if inputs[input_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: raise AssertionError( @@ -292,7 +315,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: try: header_data.append( 1 - if outputs[output_name.decode()] == TensorFormat.CHANNELS_LAST + if outputs[output_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: @@ -331,7 +354,9 @@ def _pack_with_alignment( neutron_artifacts.kernels.tobytes(), ) - def get_binary_payload(self, io_formats, neutron_model) -> bytes: + def get_binary_payload( + self, io_formats: dict[str, list[NodeFormat]], neutron_model + ) -> bytes: """ Get binary payload for provided input/output tensor formats and neutron_model. Returned data have following structure: @@ -351,7 +376,7 @@ def get_binary_payload(self, io_formats, neutron_model) -> bytes: Tensor format definition: '0x1' == CHANNELS_LAST, '0x0' == FORMATLESS (no format). :param io_formats: Dictionary with keys 'inputs' and 'outputs' that contains dictionaries - mapping tensor name to TensorFormat. + mapping tensor name to NodeFormat. :param neutron_model: Neutron model with single NeutronGraph node. :return: 16 bytes aligned binary payload. """ diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index 6564c19d7b9..24fe13555ca 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -12,6 +12,7 @@ from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from executorch.backends.nxp.quantizer.patterns import ( AbsPattern, + ActivationsConcatClusterPattern, AdaptiveAvgPoolPattern, AddmmPattern, AddTensorPattern, @@ -19,6 +20,7 @@ CatPattern, Conv1dPattern, Conv2dPattern, + ConvTranspose2dPattern, DropoutPattern, FlattenPattern, HardTanhInPlacePattern, @@ -40,6 +42,7 @@ SubTensorPattern, TanhInPlacePattern, TanhPattern, + TransposeIntPattern, ViewPattern, ) from executorch.backends.nxp.quantizer.utils import ( @@ -196,6 +199,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec): NeutronAtenQuantizer(CatPattern(), static_qconfig), NeutronAtenQuantizer(Conv1dPattern(), static_qconfig), NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig), + NeutronAtenQuantizer(ConvTranspose2dPattern(), static_qconfig), NeutronAtenQuantizer(DropoutPattern(), static_qconfig), NeutronAtenQuantizer(FlattenPattern(), static_qconfig), NeutronAtenQuantizer(HardTanhPattern(), static_qconfig), @@ -214,6 +218,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec): NeutronAtenQuantizer(SubTensorPattern(), static_qconfig), NeutronAtenQuantizer(TanhPattern(), static_qconfig), NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig), + NeutronAtenQuantizer(TransposeIntPattern(), static_qconfig), NeutronAtenQuantizer(ViewPattern(), static_qconfig), ] ) @@ -225,13 +230,16 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec): self.op_to_applied_quantizer = { pt: False for q in self.quantizers for pt in q.pattern.partition_types() } + self.cluster_quantizers = [ + NeutronAtenQuantizer(ActivationsConcatClusterPattern(self), static_qconfig) + ] def transform_for_annotation( self, model: torch.fx.GraphModule ) -> torch.fx.GraphModule: model.graph.eliminate_dead_code() # Remove dead code to simplify the graph for the passes. - model = NeutronAtenPassManager()(model).graph_module + model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module model.graph.eliminate_dead_code() # Remove dead code again, in case it was created by the passes. @@ -240,6 +248,10 @@ def transform_for_annotation( def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: self._annotate_inputs(model) + # Annotate node clusters in model + for cluster_quantizer in self.cluster_quantizers: + cluster_quantizer.annotate(model) + nodes = list(model.graph.nodes) for node in nodes: if ( diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index ccd579d5c52..90c43d1971e 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -13,6 +13,7 @@ from executorch.backends.nxp.quantizer.utils import get_bias_qparams from torch import fx from torch._ops import OpOverload +from torch.fx import Node from torchao.quantization.pt2e import PerChannelMinMaxObserver from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, @@ -20,6 +21,7 @@ QuantizationSpec, SharedQuantizationSpec, ) + from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY @@ -95,6 +97,7 @@ class SharedSpecPattern(QuantizationPattern): quantization parameters (scale and zero-point). """ + @abstractmethod def partition_types(self) -> list[torch.nn.Module]: pass @@ -199,7 +202,6 @@ def partition_types(self) -> list[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... addmm_node = fused_partition[0].nodes[-1] bias_qspec = DerivedQuantizationSpec( @@ -391,6 +393,11 @@ def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.conv1d.default] +class ConvTranspose1dPattern(ConvPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv_transpose1d.default] + + class Conv2dPattern(ConvPattern): def __init__(self, neutron_quantizer): self.neutron_quantizer = neutron_quantizer @@ -456,6 +463,51 @@ def get_anchors( ) +class ConvTranspose2dPattern(QuantizationPattern): + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv_transpose2d.input] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors: + conv_node = fused_partition[0].nodes[-1] + + bias_quantization_qspec = DerivedQuantizationSpec( + derived_from=[ + (conv_node.args[0], conv_node), + (conv_node.args[1], conv_node), + ], + derive_qparams_fn=get_bias_qparams, + dtype=torch.int32, + quant_min=-(2**31) + 1, + quant_max=2**31 - 1, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=1, + ) + + # Keep bias empty if not supplied + bias = [] + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)] + + return PartitionAnchors( + inputs=[(conv_node, NodeArgsIdx(0))], + weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)], + biases=bias, + output=[(conv_node,)], + ) + + class DropoutPattern(SharedSpecPattern): """ Quantizer for Dropout operator. @@ -639,6 +691,15 @@ def partition_types(self): return [torch.ops.aten.permute.default] +class TransposeIntPattern(SharedSpecPattern): + """ + Quantizer for Transpose Int operator. + """ + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.transpose.int] + + class ReluPattern(SingleInputBasicPattern): """ Quantizer for Relu operator. @@ -745,3 +806,147 @@ def get_anchors( return get_anchors_for_fixed_quant_specs( fused_partition, scale=1.0 / 128.0, zero_point=0 ) + + +class ActivationsConcatClusterPattern(QuantizationPattern): + """ + Quantizer for activations concat cluster pattern. + + The quantizer matches a pattern where concat node is preceded by activation nodes preceded by Conv 2D or Linear. + All activation nodes quantization parameters must be the same. Only activations, that have support for fusion + to preceding compute node on Neutron are allowed. This cluster is usually produced by MoveActivationBeforeConcat + pass. Cluster schema: + + │ │ + ┌──────▼──────┐ ┌──────▼──────┐ + │ aten.conv2d │ ... │ aten.conv2d │ + └──────┬──────┘ └──────┬──────┘ + │ │ + ┌─────▼─────┐ ┌─────▼─────┐ + │ aten.relu │ ... │ aten.relu │ + └─────┬─────┘ └─────┬─────┘ + └───────┐ ┌───────┘ + ┌──▼─────▼─┐ + │ aten.cat │ + └────┬─────┘ + │ + """ + + def __init__(self, neutron_quantizer): + self.neutron_quantizer = neutron_quantizer + self.neutron_target_info = ( + self.neutron_quantizer.neutron_target_spec.neutron_target_info + ) + + @staticmethod + def _all_activations_are_equal(activations: list[Node]) -> bool: + first_input_node = activations[0] + hardtanh_t = [ + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + ] + relu_t = [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ] + tanh_t = [ + torch.ops.aten.tanh.default, + torch.ops.aten.tanh_.default, + ] + + def _activations_are_equal(activation1: Node, activation2: Node) -> bool: + if ( # Targets are equal also with their inplace variants + (activation1.target in hardtanh_t and activation2.target in hardtanh_t) + or (activation1.target in relu_t and activation2.target in relu_t) + or (activation1.target in tanh_t and activation2.target in tanh_t) + or ( + activation1.target == torch.ops.aten.sigmoid.default + and activation2.target == torch.ops.aten.sigmoid.default + ) + ): + return True + elif ( # Hardtanh with min_val 0 and max_val 'inf' is equal to Relu + activation1.target in hardtanh_t + and activation1.args[1:] == (0.0, float("inf")) + and activation2.target in relu_t + ) or ( + activation1.target in relu_t + and activation2.target in hardtanh_t + and activation2.args[1:] == (0.0, float("inf")) + ): + return True + else: + return False + + return all( + _activations_are_equal(activation, first_input_node) + for activation in activations + ) + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.cat.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors | None: + cat_node = fused_partition[0].nodes[-1] + + # Check all cat inputs are supported activations + if not all( + self.neutron_target_info.is_supported_fused_activation__aten(input_node) + for input_node in cat_node.all_input_nodes + ): + return None + + # Check all cat inputs are equal activations + if not self._all_activations_are_equal(cat_node.all_input_nodes): + return None + + # Check compute nodes are Conv 2D or Linear + if not all( + self.neutron_target_info.is_fusable_conv_or_linear__aten(compute_node) + for input_node in cat_node.all_input_nodes + for compute_node in input_node.all_input_nodes + ): + return None + + # Annotate compute nodes + for input_node in cat_node.all_input_nodes: + for compute_node in input_node.all_input_nodes: + if compute_node.target not in self.neutron_quantizer.op_to_quantizer: + return None + compute_node_quantizer = self.neutron_quantizer.op_to_quantizer[ + compute_node.target + ] + compute_node_quantizer.annotate(gm) + del compute_node.meta["quantization_annotation"].output_qspec + + # Annotate activations + for input_node in cat_node.all_input_nodes: + if input_node.target not in self.neutron_quantizer.op_to_quantizer: + return None + activation_quantizer = self.neutron_quantizer.op_to_quantizer[ + input_node.target + ] + activation_quantizer.annotate(gm) + input_node.meta["quantization_annotation"].input_qspec_map = {} + + # Annotate cat node + inputs = [] + first_input_node = cat_node.all_input_nodes[0] + for idx in range(len(cat_node.all_input_nodes)): + inputs.append( + ( + cat_node, + NodeArgsIdx(0, idx), + SharedQuantizationSpec(first_input_node), + ) + ) + outputs = [(cat_node, SharedQuantizationSpec(first_input_node))] + + return PartitionAnchors( + inputs=inputs, + weights=[], + biases=[], + output=outputs, + ) diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index 12c722a8ab3..389526111cb 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -9,17 +9,20 @@ import itertools from collections import OrderedDict +from collections.abc import Iterable from typing import Any, Dict, List, Tuple, Type import torch from torch import fx from torch._ops import OpOverload +from torch.export import ExportedProgram from torch.fx.passes.utils.source_matcher_utils import ( check_subgraphs_connected, SourcePartition, ) from torchao.quantization.pt2e import ObserverOrFakeQuantize -from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY, Quantizer def is_annotated(nodes: List[fx.Node]) -> bool: @@ -149,3 +152,29 @@ def find_sequential_partitions_aten( if _partitions_sequential(candidate): fused_partitions.append(candidate) return fused_partitions + + +def post_training_quantize( + model: ExportedProgram | fx.GraphModule, + calibration_inputs: Iterable[tuple[torch.Tensor, ...]], + quantizer: Quantizer, +) -> fx.GraphModule: + """Quantize the provided model. + + :param model: Aten model (or it's GraphModule representation) to quantize. + :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model + input. Or an iterator over such tuples. + :param quantizer: Quantizer to use. + + :return: Quantized GraphModule. + """ + + if isinstance(model, ExportedProgram): + model = model.module() + + m = prepare_pt2e(model, quantizer) + for data in calibration_inputs: + m(*data) + m = convert_pt2e(m) + + return m diff --git a/backends/nxp/tests/TARGETS b/backends/nxp/tests/TARGETS index c8ccd5fe900..f492111aff2 100644 --- a/backends/nxp/tests/TARGETS +++ b/backends/nxp/tests/TARGETS @@ -1,4 +1,3 @@ -load("@fbsource//tools/target_determinator/macros:ci.bzl", "ci") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") load("@fbcode_macros//build_defs:python_pytest.bzl", "python_pytest") @@ -51,9 +50,5 @@ python_pytest( "//executorch/backends/nxp:neutron_backend", ":executorch_pipeline", ":models", - ], - labels = [ - "local_only", - ci.skip_test(), - ], + ] ) diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index d209ce3ea01..a2dd8cade7b 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -23,6 +23,7 @@ from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import post_training_quantize from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, @@ -32,12 +33,12 @@ ) from torch import nn from torch.export import export -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import Quantizer -default_neutron_converter_flavor = "SDK_25_09" + +neutron_converter_flavor = "SDK_25_09" neutron_target_spec = NeutronTargetSpec( - target="imxrt700", neutron_converter_flavor=default_neutron_converter_flavor + target="imxrt700", neutron_converter_flavor=neutron_converter_flavor ) @@ -47,17 +48,6 @@ class ModelInputSpec: dtype: torch.dtype = torch.float32 -def _quantize_model( - model, quantizer, calibration_inputs: list[tuple[torch.Tensor, ...]] -): - m = prepare_pt2e(model, quantizer) - for data in calibration_inputs: - m(*data) - m = convert_pt2e(m) - - return m - - def get_random_calibration_inputs( input_spec: tuple[ModelInputSpec, ...] ) -> list[tuple[torch.Tensor, ...]]: @@ -101,15 +91,15 @@ def to_quantized_edge_program( [tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]] ] = get_random_calibration_inputs, target="imxrt700", - neutron_converter_flavor=default_neutron_converter_flavor, + neutron_converter_flavor=neutron_converter_flavor, remove_quant_io_ops=False, custom_delegation_options=CustomDelegationOptions(), # noqa B008 get_quantizer_fn=None, + use_neutron_for_format_conversion=True, ) -> EdgeProgramManager: _neutron_target_spec = NeutronTargetSpec(target, neutron_converter_flavor) if get_quantizer_fn is None: get_quantizer_fn = partial(_get_default_quantizer, _neutron_target_spec) - quantizer = get_quantizer_fn() calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec)) example_input = calibration_inputs[0] @@ -119,16 +109,17 @@ def to_quantized_edge_program( exir_program_aten = torch.export.export(model, example_input, strict=True) - exir_program_aten__module_quant = _quantize_model( - exir_program_aten.module(), - quantizer, + exir_program_aten__module_quant = post_training_quantize( + exir_program_aten, calibration_inputs, + get_quantizer_fn(), ) compile_spec = generate_neutron_compile_spec( target, operators_not_to_delegate=operators_not_to_delegate, neutron_converter_flavor=neutron_converter_flavor, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, ) partitioners = [ NeutronPartitioner( @@ -154,8 +145,13 @@ def to_quantized_edge_program( def to_quantized_executorch_program( model: torch.nn.Module, input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]], + use_neutron_for_format_conversion: bool = True, ) -> ExecutorchProgramManager: - edge_program_manager = to_quantized_edge_program(model, input_spec) + edge_program_manager = to_quantized_edge_program( + model, + input_spec, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, + ) return edge_program_manager.to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py index 315c76a7614..96b9abfe117 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py @@ -14,9 +14,10 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, - ToNCHWPreprocess, - ToNHWCPreprocess, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) + from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -67,7 +68,9 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -80,8 +83,8 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=1.0, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py index 9c8235f7eda..a80d2014487 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py @@ -47,7 +47,9 @@ def test_adaptive_avg_pool_2d_delegated_quant_conversion( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = [str(node) for node in edge_program.graph.nodes] # Input size is a multiple of output size, can be converted to AveragePool, node is delegated @@ -91,7 +93,9 @@ def test_adaptive_avg_pool_2d_non_delegated_quant_conversion( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Input size is not a multiple of output size, cannot be converted to AveragePool, node is not delegated @@ -122,7 +126,9 @@ def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 2c3107eae77..02e799723d4 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -103,7 +103,9 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index bcdbd955c71..7aed0236043 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -91,6 +92,9 @@ def test_avg_pool_2d_conversion(input_shape, padding, count_include_pad): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -145,7 +149,9 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -172,7 +178,9 @@ def test_avg_pool_2d_quant_conversion__padded(mocker): ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture the converter operators. ops = ops_spy.spy_return.sub_graphs[0].operators.vector diff --git a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py index 2d3ec7929be..590b0be6a6b 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py @@ -44,6 +44,18 @@ def forward(self, *inputs: torch.Tensor): return torch.cat(list(inputs), self.dim) +class AddCatModule(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, *inputs: torch.Tensor): + inputs = [input_ + input_ for input_ in inputs] + + return torch.cat(list(inputs), self.dim) + + class CatConvModule(torch.nn.Module): def __init__(self, dim: int, channels: int = 4): @@ -73,7 +85,7 @@ def forward(self, *inputs: torch.Tensor): ], ) def test_cat__same_shapes(dim, num_inputs, rank, mocker): - input_shape = tuple([2, 8, 8, 8, 8][-rank:]) + input_shape = tuple([8, 8, 8, 8][:rank]) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") @@ -134,11 +146,23 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker): ) -@pytest.mark.parametrize("dim", [0, -4]) -@pytest.mark.parametrize("num_inputs", [2]) -def test_cat__unsupported_dim__imxrt700(dim, num_inputs): - input_shape = (2, 8, 6, 8) - +@pytest.mark.parametrize( + "dim, input_shape", + [ + pytest.param(0, (1, 8, 8, 8), id="axis = 0"), + pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."), + pytest.param(-4, (1, 8, 8, 8), id="axis = -4"), + pytest.param(1, (1, 1, 8, 8), id="axis = 1"), + pytest.param(-3, (1, 1, 8, 8), id="axis = -3"), + pytest.param(2, (1, 1, 1, 8), id="axis = 2"), + pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), + ], +) +def test_cat__unsupported__imxrt700(dim, input_shape): + """This test is conjoined with the one below (`test_cat__context_dependent__imxrt700`). + In this case, the inputs of the `cat` are NOT compute ops, so the `cat` is NOT delegated. + """ + num_inputs = 2 quantized_program = to_quantized_edge_program( CatModule(dim), [input_shape] * num_inputs, target="imxrt700" ).exported_program() @@ -152,6 +176,32 @@ def test_cat__unsupported_dim__imxrt700(dim, num_inputs): ) +@pytest.mark.parametrize( + "dim, input_shape", + [ + pytest.param(0, (1, 8, 8, 8), id="axis = 0"), + pytest.param(0, (8, 8, 8, 8), id="axis = 0, no `1s` in the shape."), + pytest.param(-4, (1, 8, 8, 8), id="axis = -4"), + pytest.param(1, (1, 1, 8, 8), id="axis = 1"), + pytest.param(-3, (1, 1, 8, 8), id="axis = -3"), + pytest.param(2, (1, 1, 1, 8), id="axis = 2"), + pytest.param(-2, (1, 1, 1, 8), id="axis = -2"), + ], +) +def test_cat__context_dependent__imxrt700(dim, input_shape): + """This test is conjoined with the one above (`test_cat__unsupported__imxrt700`). + In this case, the inputs of the `cat` are compute ops, so the `cat` is delegated. + """ + num_inputs = 2 + ep = to_quantized_edge_program( + AddCatModule(dim), [input_shape] * num_inputs, target="imxrt700" + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.cat.default]) + assert any("lowered_module" in node.name for node in ep.graph.nodes) + + @pytest.mark.parametrize( "rank, num_inputs, dim", [ diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py index d2aafb570fa..427ddaf14a5 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py @@ -113,7 +113,7 @@ def test_conv_dropout_quant(self, inplace_dropout: bool, input_shape: tuple[int] owner=EdgeProgramToIRConverter, ) as converter_spy: quantized_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value diff --git a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py index 56be613a664..bd1f894001c 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py @@ -7,6 +7,7 @@ import pytest import torch +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, to_quantized_edge_program, @@ -101,6 +102,9 @@ def test_constant_pad_nd_conversion__channels_first(input_shape, paddings): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index d7a59cad6d6..0fabbf615c9 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -10,6 +10,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -22,10 +23,12 @@ ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.models import Conv1dModule, Conv2dModule +from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -35,12 +38,15 @@ def reseed_model_per_test_run(): np.random.seed(23) +@pytest.mark.parametrize("bias", [False, True]) @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) @pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion(stride, dilation, kernel_size, mocker): +def test_conv1d_quant_conversion(bias, stride, dilation, kernel_size, mocker): input_shape = (1, 4, 16) - model = Conv1dModule(stride=stride, dilation=dilation, kernel_size=kernel_size) + model = Conv1dModule( + bias=bias, stride=stride, dilation=dilation, kernel_size=kernel_size + ) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") ops_spy = mocker.spy(ModelBuilder, "finish") @@ -142,13 +148,17 @@ def test_conv1d_quant_conversion__padded( ) # `Conv` input zp. +@pytest.mark.parametrize("bias", [False, True]) @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [2, 1]) @pytest.mark.parametrize("kernel_size", [(1,), (3,)]) -def test_conv1d_quant_conversion__depthwise(stride, dilation, kernel_size, mocker): +def test_conv1d_quant_conversion__depthwise( + bias, stride, dilation, kernel_size, mocker +): input_shape = (1, 4, 16) group = input_shape[1] model = Conv1dModule( + bias=bias, group=group, in_channels=group, out_channels=group, @@ -369,13 +379,35 @@ def test_conv1d_quant_conversion__depthwise__padded( (1, 32, 32, 32), id="In ch 32, out ch 32, kernel 4, padding (0, 2), dilation (1, 2)", ), + pytest.param( + Conv2dModule( + in_channels=8, out_channels=32, kernel_size=5, padding=3, bias=False + ), + (1, 8, 32, 32), + id="In ch 8, out ch 32, kernel 5, padding 3, no bias", + ), + pytest.param( + Conv2dModule( + in_channels=32, + out_channels=32, + kernel_size=3, + padding=(1, 0), + dilation=(3, 1), + bias=False, + ), + (1, 32, 35, 35), + id="In ch 32, out ch 32, kernel 3, padding (1, 0), dilation (3, 1)," + "no bias", + ), ], ) def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -395,47 +427,12 @@ def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape): ) -@pytest.mark.parametrize("stride", [1, 2]) -@pytest.mark.parametrize("dilation", [1, 2]) -@pytest.mark.parametrize("kernel_shape", [[1, 2], [3, 3], [4, 1]]) -def test_conv2d_conversion__depthwise(stride, dilation, kernel_shape, mocker): - input_shape = (1, 3, 12, 16) - group = input_shape[1] - edge_program = to_edge_program( - Conv2dModule( - group=group, - in_channels=group, - out_channels=group, - stride=stride, - dilation=dilation, - kernel_size=kernel_shape, - ), - input_shape, - ).exported_program() - - input_data = np.random.random(input_shape).astype(np.float32) - - spy = mocker.spy(ModelBuilder, "finish") - - convert_run_compare( - edge_program, - input_data, - tflite_input_preprocess=ToChannelLastPreprocess(), - tflite_output_preprocess=ToChannelFirstPreprocess(), - atol=4e-7, - ) - conversion_result = spy.spy_return - ops = conversion_result.sub_graphs[0].operators.vector - - assert len(ops) == 1 - assert ops[0].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D - - +@pytest.mark.parametrize("bias", [False, True]) @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [1, 2]) @pytest.mark.parametrize("kernel_shape", [[1, 2], [3, 3], [4, 1]]) def test_conv2d_conversion__depthwise__quantized( - stride, dilation, kernel_shape, mocker + bias, stride, dilation, kernel_shape, mocker ): input_shape = (1, 4, 12, 12) group = input_shape[1] @@ -443,6 +440,7 @@ def test_conv2d_conversion__depthwise__quantized( edge_program = to_quantized_edge_program( Conv2dModule( + bias=bias, group=group, in_channels=group, out_channels=group, @@ -451,6 +449,7 @@ def test_conv2d_conversion__depthwise__quantized( kernel_size=kernel_shape, ), tuple(input_shape), + use_neutron_for_format_conversion=False, ).exported_program() ops = spy.spy_return.sub_graphs[0].operators.vector @@ -485,6 +484,9 @@ def test_conv2d_conversion__depthwise__padded(padding, mocker): tflite_input_preprocess=ToChannelLastPreprocess(), tflite_output_preprocess=ToChannelFirstPreprocess(), atol=4e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) conversion_result = spy.spy_return ops = conversion_result.sub_graphs[0].operators.vector @@ -505,6 +507,7 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): group=group, in_channels=group, out_channels=group, padding=padding ), tuple(input_shape), + use_neutron_for_format_conversion=False, ).exported_program() ops = spy.spy_return.sub_graphs[0].operators.vector @@ -517,3 +520,156 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): len(nodes) == 7 ) # input, Quant, lowered_module, delegate_call, getitem, Deq, output assert nodes[2].target == "lowered_module_0" + + +@pytest.mark.parametrize( + "model, input_shape", + [ + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d(64, 64, (1, 2), stride=(1, 2)), + (1, 64, 3, 12), + id="In ch 64, out ch 64, kernel (1, 2), stride (1, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d( + 16, 24, (1, 6), stride=(1, 6), output_padding=(0, 3) + ), + (1, 16, 7, 15), + id="In ch 16, out ch 24, kernel (1, 6), stride (1, 6), output_padding (0, 3)", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 40, (1, 4), stride=(1, 4), padding=(0, 1)), + (1, 16, 1, 27), + id="In ch 16, out ch 40, kernel (1, 4), stride (1, 4), padding (0, 1)", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2), padding=(0, 1)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2), padding (0, 1)", + ), + pytest.param( + torch.nn.ConvTranspose2d( + 8, 16, (1, 8), stride=(1, 4), output_padding=(0, 2) + ), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 8), stride (1, 4), output_padding (0, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 16, (1, 4), stride=(1, 2)), + (1, 16, 1, 16), + id="In ch 16, out ch 16, kernel (1, 4), stride (1, 2)", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2), bias=False), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2), no bias", + ), + pytest.param( + torch.nn.ConvTranspose2d( + 8, 16, (1, 4), stride=(1, 2), padding=(0, 1), bias=False + ), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2)," + "padding (0, 1), no bias", + ), + ], +) +def test_conv_transpose2d_conversion__quantized( + mocker, model: torch.nn.Module, input_shape +): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() + + # Make sure the `TransposeConv` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.convolution.default] + ) + assert any("lowered_module" in node.name for node in edge_program.graph.nodes) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + convert_run_compare( + exported_program, + tflite_input_preprocess=ToChannelLastPreprocess(), + tfl_model=tflite_flatbuffers_model, + tflite_output_preprocess=ToChannelFirstPreprocess(), + input_data=input_data, + atol=1.0, + ) + + +@pytest.mark.parametrize( + "model, input_shape", + [ + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2), dilation=(1, 2)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2), " + "dilation (1, 2) - Dilation != (1, 1)", + ), + pytest.param( + torch.nn.ConvTranspose2d(6, 16, (1, 4), stride=(1, 2)), + (1, 6, 1, 16), + id="In ch 6, out ch 16, kernel (1, 4), stride (1, 2) - In channels % num_macs != 0", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 4), stride=(1, 2)), + (1, 8, 4, 16), + id="In ch 8, out ch 16, kernel (1, 4), stride (1, 2) - Out height != 1, stride width" + " != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (2, 4), stride=(1, 2), padding=(0, 1)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (2, 4), stride (1, 2), padding " + "(0, 1) - Out height != 1, stride width != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(8, 16, (1, 5), stride=(1, 4)), + (1, 8, 1, 16), + id="In ch 8, out ch 16, kernel (1, 5), stride (1, 4) - Stride width != kernel width / 2" + ", stride width != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 12, (1, 4), stride=(3, 3)), + (1, 16, 1, 16), + id="In ch 16, out ch 12, kernel (1, 4), stride (3, 3) - Out channels % num_macs != 0", + ), + pytest.param( + torch.nn.ConvTranspose2d(64, 64, (1, 4), stride=(1, 2)), + (1, 64, 3, 12), + id="In ch 64, out ch 64, kernel (1, 4), stride (1, 2) - Out height != 1, stride width" + " != kernel width", + ), + pytest.param( + torch.nn.ConvTranspose2d(16, 40, (1, 4), stride=(1, 4), padding=(0, 1)), + (1, 16, 4, 27), + id="In ch 16, out ch 40, kernel (1, 4), stride (1, 4), padding (0, 1) - Padding width " + "!= 1 and input height != 1", + ), + ], +) +def test_conv_transpose2d_non_delegated_conversion__quantized( + model: torch.nn.Module, input_shape +): + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 15 + assert ( + nodes[11].target.__name__ == "aten.convolution.default" + ) # TransposeConv not delegated. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py index c4bc559817b..dad8ce6a0e3 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -42,7 +42,9 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -79,7 +81,9 @@ def test_custom_hardtanh_quant( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py index 50bbf100980..8b938ef7fff 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, @@ -76,6 +77,9 @@ def test_max_pool_2d_conversion(input_shape, padding): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -103,7 +107,11 @@ def test_max_pool_2d_quant_conversion(mocker, input_shape, padding): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(MaxPool2dConvModule(padding=padding), input_shape) + _ = to_quantized_edge_program( + MaxPool2dConvModule(padding=padding), + input_shape, + use_neutron_for_format_conversion=False, + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py index a634416f8a7..ee69b1ea352 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py @@ -1,3 +1,8 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import numpy as np import pytest import torch @@ -8,10 +13,12 @@ from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule +from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -21,19 +28,37 @@ def reseed_model_per_test_run(): np.random.seed(23) +class MeanDimModule(torch.nn.Module): + def __init__(self, dim, keepdim): + super().__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, x): + return torch.mean(x, dim=self.dim, keepdim=self.keepdim) + + @pytest.mark.parametrize( "input_shape, dim", [ pytest.param((1, 4, 8, 8), (-1, -2), id="Dim -1, -2."), + pytest.param((1, 4, 8, 8), (-2, -1), id="Dim -2, -1."), + pytest.param((1, 4, 8, 8), (2, 3), id="Dim 2, 3."), + pytest.param((1, 4, 8, 8), (3, 2), id="Dim 3, 2."), ], ) -def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True): - model = MeanDimConvModule(dim, keeepdim) +def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keepdim=True): + model = MeanDimConvModule(dim, keepdim) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + ep = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() + # Make sure the `mean.dim` was delegated. + assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert any("lowered_module" in n.name for n in ep.graph.nodes) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -61,16 +86,16 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True) ], ) @pytest.mark.parametrize( - "keeepdim", + "keepdim", [ pytest.param(False, id="Don't keep dim."), pytest.param(True, id="Keep dim."), ], ) def test_mean_dim_linear_unsupported_quant_conversion( - mocker, input_shape, dim, keeepdim + mocker, input_shape, dim, keepdim ): - model = MeanDimLinearModule(dim, keeepdim) + model = MeanDimLinearModule(dim, keepdim) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") @@ -107,19 +132,21 @@ def test_mean_dim_linear_unsupported_quant_conversion( ], ) @pytest.mark.parametrize( - "keeepdim", + "keepdim", [ pytest.param(False, id="Don't keep dim."), pytest.param(True, id="Keep dim."), ], ) -def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim): - model = MeanDimConvModule(dim, keeepdim) +def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keepdim): + model = MeanDimConvModule(dim, keepdim) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated @@ -140,3 +167,93 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, ke tflite_output_preprocess=ToChannelFirstPreprocess(), tfl_model=tflite_flatbuffers_model, ) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((1, 2, 3, 8), (1, 2), id="Dim 1, 2."), + pytest.param((1, 2, 3, 8), (2, 1), id="Dim 2, 1."), + pytest.param((1, 2, 3, 8), (-3, -2), id="Dim -3, -2."), + pytest.param((1, 2, 3, 8), (-2, -3), id="Dim -2, -3."), + ], +) +def test_mean_dim__formatless__supported(mocker, input_shape, dim, keepdim=True): + model = MeanDimModule(dim, keepdim) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + ep = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `mean.dim` was delegated. + assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert any("lowered_module" in n.name for n in ep.graph.nodes) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + convert_run_compare( + exported_program, + input_data=input_data, + tfl_model=tflite_flatbuffers_model, + atol=1, + ) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param((1, 2, 3, 8), (2, 3), id="Dim 2, 3."), + ], +) +def test_mean_dim__formatless__unsupported(input_shape, dim, keepdim=True): + model = MeanDimModule(dim, keepdim) + + ep = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `mean.dim` was NOT delegated. + assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert not any("lowered_module" in n.name for n in ep.graph.nodes) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param( + (1, 8, 8, 4), (1, 2), id="Dim 1, 2 (supported), channels = 4 (unsupported)." + ), + ], +) +def test_mean_dim__formatless__unsupported_channels(input_shape, dim, keepdim=True): + model = MeanDimModule(dim, keepdim) + + ep = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `mean.dim` was NOT delegated. + assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) + assert not any("lowered_module" in n.name for n in ep.graph.nodes) + + +@pytest.mark.parametrize( + "input_shape, dim", + [ + pytest.param( + (1, 4, 8, 8), (2, 3), id="Dim 2, 3 (supported), channels = 5 (unsupported)." + ), + ], +) +def test_mean_dim__channels_first__unsupported_channels(input_shape, dim, keepdim=True): + model = MeanDimConvModule( + dim, keepdim, out_channels=5 + ) # Only multiples of 8 (num_macs) are supported. + + # Run conversion + ep = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `mean.dim` was NOT delegated. + assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py index d25e2759cc8..57d15aefdc0 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py @@ -3,8 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import unittest + +import kgb import numpy as np -import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( @@ -13,52 +15,312 @@ from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, - ToNCHWPreprocess, - ToNHWCPreprocess, + graph_contains_any_of_ops, ) from executorch.backends.nxp.tests.models import Conv2dModule +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized from torch.export import ExportedProgram -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(23) - np.random.seed(23) +class Conv2dTransposeModule(torch.nn.Module): + def __init__(self, in_channels: int, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + self.conv = Conv2dModule( + in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1) + ) + + def forward(self, x): + x = self.conv(x) + return torch.transpose(x, self.dim0, self.dim1) + + +class Conv2dPermuteModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = self.conv(x) + return torch.permute(x, self.perm) + + +class PermuteConv2dModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = torch.permute(x, self.perm) + return self.conv(x) -class Conv2dPermuteCopyModule(torch.nn.Module): - def __init__(self, new_dims: tuple[int, ...]): +class PermuteConv2dPermuteModule(torch.nn.Module): + def __init__( + self, in_channels: int, perm1: tuple[int, ...], perm2: tuple[int, ...] + ): super().__init__() - self.new_dims = new_dims - self.conv = Conv2dModule() + self.perm1 = perm1 + self.perm2 = perm2 + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) def forward(self, x): + x = torch.permute(x, self.perm1) x = self.conv(x) - return torch.permute(x, self.new_dims) + x = torch.permute(x, self.perm2) + return x -def test_permute_copy_quant_conversion__with_bias(mocker): - input_shape = (1, 4, 8, 8) - new_dims = (0, 2, 3, 1) +class LinearPermuteModule(torch.nn.Module): + def __init__(self, in_features: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.fc = torch.nn.Linear(in_features, in_features) + + def forward(self, x): + x = self.fc(x) + return torch.permute(x, self.perm) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - # Run conversion - _ = to_quantized_edge_program(Conv2dPermuteCopyModule(new_dims), input_shape) +class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return + @parameterized.expand( + [ + ["To channel first permutation", (1, 16, 8, 8), (0, 3, 1, 2)], + ["To channel last permutation", (1, 16, 8, 8), (0, 2, 3, 1)], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_input( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = Conv2dPermuteModule(input_shape[1], perm) - # Capture converted program - edge_program: ExportedProgram = converter_spy.call_args.args[1] + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) - convert_run_compare( - edge_program, - input_data, - tfl_model=tflite_flatbuffers_model, - atol=1.0, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["To channel first permutation", (1, 8, 8, 8), (0, 3, 1, 2)], + ["To channel last permutation", (1, 8, 8, 8), (0, 2, 3, 1)], + ] ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_output( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dModule(input_shape[1], perm) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["nchw->nhwc ... nchw->nhwc", (1, 8, 8, 8), (0, 2, 3, 1), (0, 2, 3, 1)], + ["nchw->nhwc ... nhwc->nchw", (1, 8, 8, 8), (0, 2, 3, 1), (0, 3, 1, 2)], + ["nhwc->nchw ... nhwc->nchw", (1, 8, 8, 8), (0, 3, 1, 2), (0, 3, 1, 2)], + ["nhwc->nchw ... nchw->nhwc", (1, 8, 8, 8), (0, 3, 1, 2), (0, 2, 3, 1)], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_io( + self, _: str, input_shape, perm1, perm2 + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dPermuteModule(input_shape[1], perm1, perm2) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["Permutation can be replaced by reshapes", (10, 1, 8), (0, 2, 1)], + ["Permutation can be replaced by reshapes", (10, 1, 1), (2, 1, 0)], + ["Permutation is identical and can be removed", (10, 1, 8), (0, 1, 2)], + ] + ) + def test_permute_copy_conversion__from_permute_3D__quantized( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + # Run conversion + edge_program = to_quantized_edge_program( + LinearPermuteModule(input_shape[2], perm), input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["Transpose dims 1 and 2", (1, 16, 8, 8), (0, 2, 1, 3)], + ["To (2, 0, 1, 3) permutation", (1, 16, 8, 8), (2, 0, 1, 3)], + ["To (3, 1, 2, 0) permutation", (1, 16, 8, 8), (3, 1, 2, 0)], + ["To (3, 1, 0, 2) permutation", (1, 16, 8, 8), (3, 1, 0, 2)], + ] + ) + def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized( + self, _: str, input_shape, perm + ): + model = Conv2dPermuteModule(input_shape[1], perm) + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 10 + assert ( + nodes[6].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. + + @parameterized.expand( + [ + ["Transpose dims 1 and 2", (1, 16, 8, 8), 1, 2], + ["Transpose dims 2 and 3", (1, 16, 8, 8), 2, 3], + ] + ) + def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized( + self, _: str, input_shape, dim0, dim1 + ): + model = Conv2dTransposeModule(input_shape[1], dim0, dim1) + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 10 + assert ( + nodes[6].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py index 8d903e3e0b5..cf0e0135ffe 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py @@ -67,7 +67,9 @@ def test_relu_with_conv_quant_conversion(mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(ConvReLUModule(), input_shape) + _ = to_quantized_edge_program( + ConvReLUModule(), input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, _ = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py index c5d7d4d6a38..382266e9cb1 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py @@ -33,7 +33,9 @@ def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - to_quantized_edge_program(model, input_shape).exported_program() + to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py index 98566ff1ad6..336c3cc9afd 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -118,7 +118,9 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape): y_input_shape = (n, 8, h, w) # Run conversion - _ = to_quantized_edge_program(model, [x_input_shape, y_input_shape]) + _ = to_quantized_edge_program( + model, [x_input_shape, y_input_shape], use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index ca750719a32..eb5fc6600f5 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -62,7 +62,7 @@ def test_conv_tanh( ) quantized_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value exported_program: ExportedProgram = converter_spy.calls[-1].args[0] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py index 448a9753000..fac0a1fffee 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py @@ -12,6 +12,8 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -146,6 +148,9 @@ def test__channels_first_to_4d(mocker): input_data, tflite_input_preprocess=ToNHWCPreprocess(), atol=2.0e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) tflite_model = converter_spy.spy_return @@ -243,6 +248,7 @@ def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_ channels=input_shape[1], channels_view_out=channels_view_out ), input_shape, + use_neutron_for_format_conversion=False, ) # Capture generated model diff --git a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py index 17b040fbc3d..b5e701ab239 100644 --- a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py +++ b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py @@ -51,7 +51,10 @@ def test_remove_io_quant_ops_pass__cifarnet(): model = CifarNet().get_eager_model() input_shape = (1, 3, 32, 32) edge_program_manager = to_quantized_edge_program( - model, input_shape, remove_quant_io_ops=True + model, + input_shape, + remove_quant_io_ops=True, + use_neutron_for_format_conversion=False, ) exec_prog = edge_program_manager.to_executorch( diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index c4d9491d4a7..2bd1f2b6d77 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -494,9 +494,9 @@ def forward(self, x): class MeanDimConvModule(torch.nn.Module): - def __init__(self, dim, keepdim): + def __init__(self, dim, keepdim, out_channels=8): super().__init__() - self.conv = Conv2dModule(stride=1, padding=1) + self.conv = Conv2dModule(stride=1, padding=1, out_channels=out_channels) self.dim = dim self.keepdim = keepdim diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py index fce11ce5aa2..eeb4b03d7a6 100644 --- a/backends/nxp/tests/test_batch_norm_fusion.py +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -18,7 +18,10 @@ from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.view_copy_converter import ( ViewCopyConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executorch_pipeline import ( + neutron_target_spec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck from torch import nn @@ -98,7 +101,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]): program = torch.export.export(module, example_input, strict=True) og_module = program.module() - pm = NeutronAtenPassManager() + pm = NeutronAtenPassManager(neutron_target_spec) graph_module_out = pm(deepcopy(program.module())).graph_module # Make sure the fusion worked. @@ -133,7 +136,7 @@ def test_batch_norm_linear_fusing(bias: bool): program = torch.export.export(module, example_input, strict=True) og_module = program.module() - pm = NeutronAtenPassManager() + pm = NeutronAtenPassManager(neutron_target_spec) graph_module_out = pm(deepcopy(program.module())).graph_module # Make sure the fusion worked. diff --git a/backends/nxp/tests/test_gru_splitting.py b/backends/nxp/tests/test_gru_splitting.py index a2e9d324f69..297f9677fb2 100644 --- a/backends/nxp/tests/test_gru_splitting.py +++ b/backends/nxp/tests/test_gru_splitting.py @@ -13,6 +13,7 @@ from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import ( SplitGRUBasedOnNumLayers, ) +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec @pytest.fixture(autouse=True) @@ -94,7 +95,9 @@ def test_gru_splitting__with_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. - pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) post_pass_output = [t.detach() for t in exir_program_aten(*example_input)] @@ -143,7 +146,9 @@ def test_gru_splitting__no_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. - pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) post_pass_output = [t.detach() for t in exir_program_aten(*example_input)] @@ -193,7 +198,9 @@ def test_gru_splitting__bidirectional__no_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. - pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) nodes = list(exir_program_aten.graph.nodes) @@ -239,7 +246,9 @@ def test_gru_splitting__bidirectional__with_bias(num_layers): ) # Just 1 `GRU` in the model. # Run pre-processing passes of the float32 aten dialect program. - pytorch_pass_manager = NeutronAtenPassManager([SplitGRUBasedOnNumLayers()]) + pytorch_pass_manager = NeutronAtenPassManager( + neutron_target_spec, [SplitGRUBasedOnNumLayers()] + ) pytorch_pass_manager(exir_program_aten) nodes = list(exir_program_aten.graph.nodes) diff --git a/backends/nxp/tests/test_integration.py b/backends/nxp/tests/test_integration.py index d31b22c9ce9..3bd5f3e1487 100644 --- a/backends/nxp/tests/test_integration.py +++ b/backends/nxp/tests/test_integration.py @@ -39,7 +39,9 @@ def test_conv_fc_softmax__to_executorch_program(): def test_cifarnet(): model = CifarNet().get_eager_model().eval() input_shape = (1, 3, 32, 32) - exec_prog = to_quantized_executorch_program(model, input_shape) + exec_prog = to_quantized_executorch_program( + model, input_shape, use_neutron_for_format_conversion=False + ) delegation_info = get_delegation_info(exec_prog.exported_program().graph_module) assert delegation_info.num_delegated_subgraphs == 1 diff --git a/backends/nxp/tests/test_linear_and_add_fusion.py b/backends/nxp/tests/test_linear_and_add_fusion.py index 16d3c4140a2..222d748001c 100644 --- a/backends/nxp/tests/test_linear_and_add_fusion.py +++ b/backends/nxp/tests/test_linear_and_add_fusion.py @@ -18,6 +18,7 @@ from executorch.backends.nxp.aten_passes.remove_nodes_with_known_outputs import ( RemoveNodesWithKnownOutputs, ) +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops from parameterized import parameterized @@ -121,10 +122,11 @@ def test_linear_add_fusing__static__no_bias__valid_shape( original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -167,10 +169,11 @@ def test_linear_add_fusing__static__no_bias__invalid_shape( original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -209,10 +212,11 @@ def test_linear_add_fusing__static__bias__valid_shape( original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -253,10 +257,11 @@ def test_linear_add_fusing__static__no_bias__reverse_order(self): original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -295,10 +300,11 @@ def test_linear_add_fusing__static__bias__reverse_order(self): original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -340,10 +346,11 @@ def test_linear_add_fusing__static__alpha__no_bias(self): original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -381,10 +388,11 @@ def test_linear_add_fusing__static__alpha__bias(self): original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -424,10 +432,11 @@ def test_linear_add_fusing__static__alpha__reversed_add_inputs(self): original_module = program.module() modified_module = NeutronAtenPassManager( + neutron_target_spec, [ RemoveNodesWithKnownOutputs(), # Make the added tensor static. FuseLinearAndAddPass(), - ] + ], )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. @@ -474,9 +483,9 @@ def test_linear_add_fusing__dynamic__no_bias__valid_shape( program = torch.export.export(module, example_input, strict=True) original_module = program.module() - modified_module = NeutronAtenPassManager([FuseLinearAndAddPass()])( - deepcopy(program.module()) - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. original_nodes = list(original_module.graph.nodes) @@ -513,9 +522,9 @@ def test_linear_add_fusing__dynamic__no_bias__invalid_shape( program = torch.export.export(module, example_input, strict=True) original_module = program.module() - modified_module = NeutronAtenPassManager([FuseLinearAndAddPass()])( - deepcopy(program.module()) - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. original_nodes = list(original_module.graph.nodes) @@ -550,9 +559,9 @@ def test_linear_add_fusing__dynamic__bias__valid_shape( program = torch.export.export(module, example_input, strict=True) original_module = program.module() - modified_module = NeutronAtenPassManager([FuseLinearAndAddPass()])( - deepcopy(program.module()) - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. original_nodes = list(original_module.graph.nodes) @@ -584,9 +593,9 @@ def test_linear_add_fusing__dynamic__reverse_order(self): program = torch.export.export(module, example_input, strict=True) original_module = program.module() - modified_module = NeutronAtenPassManager([FuseLinearAndAddPass()])( - deepcopy(program.module()) - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. original_nodes = list(original_module.graph.nodes) @@ -618,9 +627,9 @@ def test_linear_add_fusing__dynamic__alpha(self): program = torch.export.export(module, example_input, strict=True) original_module = program.module() - modified_module = NeutronAtenPassManager([FuseLinearAndAddPass()])( - deepcopy(program.module()) - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [FuseLinearAndAddPass()] + )(deepcopy(program.module())).graph_module # Make sure the module wasn't broken. original_nodes = list(original_module.graph.nodes) diff --git a/backends/nxp/tests/test_move_activation_before_concatenation.py b/backends/nxp/tests/test_move_activation_before_concatenation.py new file mode 100644 index 00000000000..cede3e41994 --- /dev/null +++ b/backends/nxp/tests/test_move_activation_before_concatenation.py @@ -0,0 +1,959 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import unittest + +import kgb +import numpy as np +import torch +from executorch.backends.nxp.aten_passes.move_activation_before_concat import ( + MoveActivationBeforeConcat, +) +from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( + NeutronAtenPassManager, +) +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import post_training_quantize +from executorch.backends.nxp.tests.executorch_pipeline import ( + get_random_calibration_inputs, + neutron_target_spec, + to_model_input_spec, + to_quantized_edge_program, +) +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + graph_contains_any_of_ops, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) +from executorch.backends.nxp.tests.models import get_activation +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized +from torch import nn +from torch.export import ExportedProgram +from torch.fx import GraphModule + +concat_cluster_ops = [ + exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.cat.default, +] + + +class ConvConcatActivationModule(torch.nn.Module): + def __init__(self, activation: str, inplace: bool, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + in_channels, + (3, 3), + padding=1, + ) + + self.activation = get_activation(activation, inplace) + self.eval() + + def forward(self, x): + x1 = self.conv(x) + x2 = self.conv(x) + x = torch.cat((x1, x2), dim=1) + return self.activation(x) + + +class LinearConcatActivationModule(nn.Module): + def __init__( + self, activation: str, inplace: bool, in_channels: int, mode: str = "linear" + ): + super().__init__() + self.mode = mode.lower() + assert self.mode in [ + "linear", + "addmm", + "mm", + ], "Mode must be 'linear', 'addmm', or 'mm'" + + if self.mode == "linear": + self.linear = nn.Linear(in_channels, in_channels) + else: + # Manual weight and bias for addmm/mm. + self.weight = nn.Parameter(torch.empty(in_channels, in_channels)) + self.bias = nn.Parameter(torch.empty(in_channels)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + self.activation = get_activation(activation, inplace) + self.eval() + + def forward(self, x): + x1, x2 = None, None + + if self.mode == "linear": + x1 = self.linear(x) + x2 = self.linear(x) + if self.mode == "addmm": + x1 = torch.addmm(self.bias, x, self.weight) + x2 = torch.addmm(self.bias, x, self.weight) + elif self.mode == "mm": + x1 = torch.mm(x, self.weight) + x2 = torch.mm(x, self.weight) + + x = torch.cat((x1, x2), dim=1) + return self.activation(x) + + +class ConvActivationConcatModule(torch.nn.Module): + def __init__( + self, + activation1: str, + activation2: str, + act1_inplace: bool, + act2_inplace: bool, + in_channels: int, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + in_channels, + (3, 3), + padding=1, + ) + + self.activation1 = get_activation(activation1, act1_inplace) + self.activation2 = get_activation(activation2, act2_inplace) + self.eval() + + def forward(self, x): + x1 = self.conv(x) + x1 = self.activation1(x1) + x2 = self.conv(x) + x2 = self.activation2(x2) + return torch.cat((x1, x2), dim=1) + + +class LinearActivationConcatModule(torch.nn.Module): + def __init__( + self, + activation1: str, + activation2: str, + act1_inplace: bool, + act2_inplace: bool, + in_channels: int, + ): + super().__init__() + self.linear = nn.Linear(in_channels, in_channels) + + self.activation1 = get_activation(activation1, act1_inplace) + self.activation2 = get_activation(activation2, act2_inplace) + self.eval() + + def forward(self, x): + x1 = self.linear(x) + x1 = self.activation1(x1) + x2 = self.linear(x) + x2 = self.activation2(x2) + return torch.cat((x1, x2), dim=1) + + +class TestMoveActivationBeforeConcat(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests. + + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat__conv(self, activation, inplace): + input_shape = (1, 3, 8, 8) + model = ConvConcatActivationModule( + activation=activation, inplace=inplace, in_channels=3 + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 8 + cat_node = nodes[5] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[6] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. + assert len(nodes) == 9 + cat_node = nodes[7] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[8].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. + neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = post_training_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + ) + + # Check convolution and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 26 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[18] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[19] + ) + ) + assert ( + nodes[20].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat__linear(self, activation, inplace): + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="linear" + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 8 + cat_node = nodes[5] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[6] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. + assert len(nodes) == 9 + cat_node = nodes[7] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[8].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. + neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = post_training_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + ) + + # Check linear and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 22 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[10] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[11] + ) + ) + assert ( + nodes[12].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat__addmm(self, activation, inplace): + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="addmm" + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 8 + cat_node = nodes[5] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[6] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. + assert len(nodes) == 9 + cat_node = nodes[7] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[8].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. + neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = post_training_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + ) + + # Check addmm and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 22 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[10] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[11] + ) + ) + assert ( + nodes[12].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat__mm(self, activation, inplace): + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="mm" + ) + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export( + model, example_input, strict=True + ).module() + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + nodes = list(exir_program_aten.graph.nodes) + assert len(nodes) == 7 + cat_node = nodes[4] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[5] + ) + ) + + # Apply the optimization. + NeutronAtenPassManager( + neutron_target_spec, + [MoveActivationBeforeConcat(neutron_target_spec)], + )(exir_program_aten) + + nodes = list(exir_program_aten.graph.nodes) + + # Make sure the optimization was applied. + assert len(nodes) == 8 + cat_node = nodes[6] + assert cat_node.target == torch.ops.aten.cat.default + assert all( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + input_node + ) + and len(input_node.users) == 1 + for input_node in cat_node.all_input_nodes + ) + assert nodes[7].target == "output" + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert np.allclose(outputs_before[0], outputs_after[0]) + + # Run pre-processing passes of the float32 aten dialect program. + neutron_aten_pass_manager = NeutronAtenPassManager(neutron_target_spec) + neutron_aten_pass_manager(exir_program_aten) # All passes by default. + + exir_program_aten_quant = post_training_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + ) + + # Check mm and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 19 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[7] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[8] + ) + ) + assert ( + nodes[9].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[11] + ) + assert ( + neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[12] + ) + ) + assert ( + nodes[13].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat_quantization__conv( + self, activation, inplace + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8, 8, 8) + model = ConvConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8 + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat_quantization__linear( + self, activation, inplace + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="linear" + ) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat_quantization__addmm( + self, activation, inplace + ): + torch.manual_seed(23) + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="addmm" + ) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["relu", True], + ["relu", False], + ["relu6", True], + ["relu6", False], + ["tanh", True], + ["tanh", False], + ["sigmoid", False], + ] + ) + def test_move_activation_before_concat_quantization__mm(self, activation, inplace): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + input_shape = (1, 8) + model = LinearConcatActivationModule( + activation=activation, inplace=inplace, in_channels=8, mode="mm" + ) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=concat_cluster_ops + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + ) + + @parameterized.expand( + [ + ["relu", "relu", True, False], + ["relu6", "relu6", False, True], + ["tanh", "tanh", True, False], + ["sigmoid", "sigmoid", False, True], + ["relu", "relu_hardtanh", True, True], + ] + ) + def test_concat_cluster_quantization__conv( + self, activation1, activation2, act1_inplace, act2_inplace + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + with kgb.spy_on( + post_training_quantize, call_original=True + ) as quantizer_spy: + input_shape = (1, 8, 8, 8) + model = ConvActivationConcatModule( + activation1, activation2, act1_inplace, act2_inplace, in_channels=8 + ) + + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=concat_cluster_ops, + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[ + -1 + ].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + exir_program_aten_quant: GraphModule = quantizer_spy.calls[ + -1 + ].return_value + + # Check convolution and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 26 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[18] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[19] + ) + assert ( + nodes[20].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + input_data = ( + np.random.random(input_shape).astype(np.float32) * 50 + ).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) + + @parameterized.expand( + [ + ["relu", "relu", True, False], + ["relu6", "relu6", False, True], + ["tanh", "tanh", True, False], + ["sigmoid", "sigmoid", False, True], + ["relu", "relu_hardtanh", True, True], + ] + ) + def test_concat_cluster_quantization__linear( + self, activation1, activation2, act1_inplace, act2_inplace + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, + call_original=True, + owner=EdgeProgramToIRConverter, + ) as converter_spy: + with kgb.spy_on( + post_training_quantize, call_original=True + ) as quantizer_spy: + input_shape = (1, 8) + model = LinearActivationConcatModule( + activation1, activation2, act1_inplace, act2_inplace, in_channels=8 + ) + + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure that all nodes were delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, + ops=concat_cluster_ops, + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[ + -1 + ].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + exir_program_aten_quant: GraphModule = quantizer_spy.calls[ + -1 + ].return_value + + # Check linear and activation are in same QDQ cluster. + nodes = list(exir_program_aten_quant.graph.nodes) + assert len(nodes) == 22 + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[10] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[11] + ) + assert ( + nodes[12].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + assert neutron_target_spec.neutron_target_info.is_fusable_conv_or_linear__aten( + nodes[14] + ) + assert neutron_target_spec.neutron_target_info.is_supported_fused_activation__aten( + nodes[15] + ) + assert ( + nodes[16].target + == torch.ops.quantized_decomposed.quantize_per_tensor.default + ) + + input_data = ( + np.random.random(input_shape).astype(np.float32) * 50 + ).astype(np.int8) + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), + ) diff --git a/backends/nxp/tests/test_neutron_backend.py b/backends/nxp/tests/test_neutron_backend.py index c9917651fbd..08c66b22585 100644 --- a/backends/nxp/tests/test_neutron_backend.py +++ b/backends/nxp/tests/test_neutron_backend.py @@ -21,7 +21,9 @@ def test_neutron_backend__single_conv_model(): def test_neutron_backend__single_conv_model__payload_header_channels_last(): edge_program_manager = to_quantized_edge_program( - Conv2dModule(bias=False), (1, 4, 32, 32) + Conv2dModule(bias=False), + (1, 4, 32, 32), + use_neutron_for_format_conversion=False, ) payload = ( edge_program_manager.exported_program().graph_module.lowered_module_0.processed_bytes diff --git a/backends/nxp/tests/test_neutron_backend_executor.py b/backends/nxp/tests/test_neutron_backend_executor.py index 3503403311f..6daf1570374 100644 --- a/backends/nxp/tests/test_neutron_backend_executor.py +++ b/backends/nxp/tests/test_neutron_backend_executor.py @@ -11,10 +11,13 @@ ) from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOptions import BuiltinOptions from executorch.backends.nxp.backend.ir.lib.tflite.Model import Model +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.nxp_backend import PayloadComposer from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, EdgeProgramExecutor, + graph_contains_any_of_ops, TFLiteExecutor, ToNHWCPreprocess, ) @@ -108,3 +111,217 @@ def test_conv_fc__lowered_program_and_tflite_output_match(mocker): input_data=input_data, tflite_input_preprocess=ToNHWCPreprocess(), ) + + +def test_delegating_format_related_transpose_operators__unsupported_shapes(mocker): + # This test focuses on the case when Neutron would not support the inserted Transpose operators, so they are not + # inserted, so the runtime will permute the data. + + # Make sure none of the dimensions are multiples of `num_macs` (8), for proper testing. + model = Conv2dModule(in_channels=3, out_channels=3, padding=1, stride=1) + input_shape = (1, 3, 3, 3) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure the `Transpose` ops are NOT in the IR model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 2 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [1, 1]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_case(mocker): + # Make sure the output channels (channels for the trailing Transpose), and the last input dimension (channels for + # the leading Transpose) are multiples of `num_macs``. + + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, out_channels=num_macs, padding=1, stride=1 + ) + input_shape = (1, num_macs, num_macs, num_macs) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure the `Transpose` ops ARE in the IR model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 4 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(2).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + assert ( + tflite_subgraph.Operators(3).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `0` means `channels_last`, which means the runtime will NOT transpose the data. + assert all(payload_header[3:5] == [0, 0]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_output__unsupported_input( + mocker, +): + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, + out_channels=num_macs, # The output `Transpose` will be supported. + padding=1, + stride=1, + ) + input_shape = (1, num_macs, num_macs, 3) # The input `Transpose` is not supported. + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure there is just the 1 `Transpose` in the model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 3 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + assert ( + tflite_subgraph.Operators(2).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [1, 0]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_input__unsupported_output( + mocker, +): + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, + out_channels=3, # The output `Transpose` will NOT be supported. + stride=1, + ) + input_shape = (1, num_macs, 3, num_macs) # The input `Transpose` is supported. + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure there is just the 1 `Transpose` in the model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 2 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [0, 1]) # [, ] diff --git a/backends/nxp/tests/test_per_channel_conversion.py b/backends/nxp/tests/test_per_channel_conversion.py index b988fce470d..62cbef9e151 100644 --- a/backends/nxp/tests/test_per_channel_conversion.py +++ b/backends/nxp/tests/test_per_channel_conversion.py @@ -126,6 +126,7 @@ def test_per_channel_convolution(self): get_quantizer_fn=lambda: NeutronAtenQuantizer( Conv2dPatternPerChannel(is_per_channel=True), static_qconfig ), + use_neutron_for_format_conversion=False, ) tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py index d76fbaf460d..85736039d26 100644 --- a/backends/nxp/tests/test_quantizer.py +++ b/backends/nxp/tests/test_quantizer.py @@ -376,7 +376,7 @@ def test_quantizers_order_invariance(): ) def test_quantizer__linear_w_activation(mocker, activation, inplace): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantizer_spy = mocker.spy(executorch_pipeline, "_quantize_model") + quantizer_spy = mocker.spy(executorch_pipeline, "post_training_quantize") input_shape = (1, 4) model = models.LinearActivationModule( @@ -432,7 +432,7 @@ def test_quantizer__linear_w_activation(mocker, activation, inplace): ) def test_quantizer__addmm_w_activation(mocker, activation, inplace): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantizer_spy = mocker.spy(executorch_pipeline, "_quantize_model") + quantizer_spy = mocker.spy(executorch_pipeline, "post_training_quantize") input_shape = (1, 4) model = models.LinearActivationModule( @@ -485,7 +485,7 @@ def test_quantizer__addmm_w_activation(mocker, activation, inplace): ) def test_quantizer__mm_w_activation(mocker, activation, inplace): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantizer_spy = mocker.spy(executorch_pipeline, "_quantize_model") + quantizer_spy = mocker.spy(executorch_pipeline, "post_training_quantize") input_shape = (1, 4) model = models.LinearActivationModule( @@ -538,7 +538,7 @@ def test_quantizer__mm_w_activation(mocker, activation, inplace): ) def test_quantizer__conv_w_activation(mocker, activation, inplace): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantizer_spy = mocker.spy(executorch_pipeline, "_quantize_model") + quantizer_spy = mocker.spy(executorch_pipeline, "post_training_quantize") input_shape = (1, 4, 8, 8) model = models.ConvActivationModule( diff --git a/backends/nxp/tests/test_removing_dead_code.py b/backends/nxp/tests/test_removing_dead_code.py index 00cb6775b3c..18d2f1d698e 100644 --- a/backends/nxp/tests/test_removing_dead_code.py +++ b/backends/nxp/tests/test_removing_dead_code.py @@ -10,10 +10,8 @@ import torch from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer -from executorch.backends.nxp.tests.executorch_pipeline import ( - _quantize_model, - neutron_target_spec, -) +from executorch.backends.nxp.quantizer.utils import post_training_quantize +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops @@ -55,8 +53,8 @@ def test_removing_dead_code(self): # The `NeutronQuantizer` should remove the dead code in the `transform_for_annotation()` method. quantizer = NeutronQuantizer(neutron_target_spec) - exir_program_aten_quant = _quantize_model( - exir_program_aten.module(), quantizer, [example_inputs] + exir_program_aten_quant = post_training_quantize( + exir_program_aten, [example_inputs], quantizer ) # Make sure the is no `add` operation in the graph anymore. diff --git a/backends/nxp/tests/test_removing_nodes_with_known_outputs.py b/backends/nxp/tests/test_removing_nodes_with_known_outputs.py index 8f5549c8526..0c496356791 100644 --- a/backends/nxp/tests/test_removing_nodes_with_known_outputs.py +++ b/backends/nxp/tests/test_removing_nodes_with_known_outputs.py @@ -17,6 +17,7 @@ from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import ( SplitGRUBasedOnNumLayers, ) +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops from parameterized import parameterized from torch import nn @@ -57,7 +58,9 @@ def test_removing_nodes__zeros(self): outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] # Apply the optimization. - NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten) + NeutronAtenPassManager(neutron_target_spec, [RemoveNodesWithKnownOutputs()])( + exir_program_aten + ) # Make sure the `aten.zeros` is no longer in the model. assert not graph_contains_any_of_ops( @@ -81,7 +84,9 @@ def test_removing_nodes__split(self, num_layers): exir_program_aten = torch.export.export(model, example_input).module() # Apply the pass to split the `aten.gru.input` into multiple instances, and add a `split` node. - NeutronAtenPassManager([SplitGRUBasedOnNumLayers()])(exir_program_aten) + NeutronAtenPassManager(neutron_target_spec, [SplitGRUBasedOnNumLayers()])( + exir_program_aten + ) # Make sure the `aten.zeros` and `torch.split` are in the model. assert graph_contains_any_of_ops( @@ -93,7 +98,9 @@ def test_removing_nodes__split(self, num_layers): outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] # Apply the optimization. - NeutronAtenPassManager([RemoveNodesWithKnownOutputs()])(exir_program_aten) + NeutronAtenPassManager(neutron_target_spec, [RemoveNodesWithKnownOutputs()])( + exir_program_aten + ) # Make sure the `aten.zeros` and `torch.split` are no longer in the model. assert not graph_contains_any_of_ops( diff --git a/backends/nxp/tests/test_split_group_convolution.py b/backends/nxp/tests/test_split_group_convolution.py index 8b2d5723dbb..f5dfcff1fde 100644 --- a/backends/nxp/tests/test_split_group_convolution.py +++ b/backends/nxp/tests/test_split_group_convolution.py @@ -18,8 +18,8 @@ from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import post_training_quantize from executorch.backends.nxp.tests.executorch_pipeline import ( - _quantize_model, get_random_calibration_inputs, neutron_target_spec, to_model_input_spec, @@ -41,10 +41,11 @@ def _quantize_and_lower_module( module: GraphModule, input_shape: tuple[int, ...], target="imxrt700" ) -> EdgeProgramManager: calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape)) - quantizer = NeutronQuantizer(neutron_target_spec) - exir_program_aten__module_quant = _quantize_model( - module, quantizer, calibration_inputs + exir_program_aten__module_quant = post_training_quantize( + module, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), ) edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) @@ -88,9 +89,9 @@ def test_split_group_convolution__2d(self, _, input_shape: list[int], group: int graph_module = torch.export.export(module, example_input, strict=True).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( - graph_module - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [SplitGroupConvolution()] + )(graph_module).graph_module # Make sure the fusion worked. original_nodes = list(original_module.graph.nodes) @@ -145,9 +146,9 @@ def test_split_group_convolution__1d(self, _, input_shape: list[int], group: int graph_module = torch.export.export(module, example_input).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( - graph_module - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [SplitGroupConvolution()] + )(graph_module).graph_module # Make sure the fusion worked. original_nodes = list(original_module.graph.nodes) @@ -199,9 +200,9 @@ def test_split_group_convolution__3d(self, _, input_shape: list[int], group: int graph_module = torch.export.export(module, example_input).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager([SplitGroupConvolution()])( - graph_module - ).graph_module + modified_module = NeutronAtenPassManager( + neutron_target_spec, [SplitGroupConvolution()] + )(graph_module).graph_module # Verify that the pass has NOT made any changes, as it is disabled for 3D convolution. original_nodes = list(original_module.graph.nodes) @@ -233,7 +234,7 @@ def test_split_group_convolution__applied_by_default(self): graph_module = torch.export.export(module, example_input).module() original_module = deepcopy(graph_module) - modified_module = NeutronAtenPassManager()( + modified_module = NeutronAtenPassManager(neutron_target_spec)( graph_module ).graph_module # Default passes. diff --git a/backends/qualcomm/CMakeLists.txt b/backends/qualcomm/CMakeLists.txt index 07166b92ea2..cc7957dfdbe 100644 --- a/backends/qualcomm/CMakeLists.txt +++ b/backends/qualcomm/CMakeLists.txt @@ -23,6 +23,47 @@ get_filename_component( _common_include_directories "${EXECUTORCH_SOURCE_DIR}/.." ABSOLUTE ) +# We only download QNN SDK when we build pip wheel for ExecuTorch. Please don't +# change this code unless you know what you are doing. +if(EXECUTORCH_BUILD_WHEEL_DO_NOT_USE) + set(_qnn_default_sdk_dir "${CMAKE_CURRENT_BINARY_DIR}/sdk/qnn") + + if(EXISTS "${_qnn_default_sdk_dir}" AND EXISTS "${_qnn_default_sdk_dir}/lib") + message(STATUS "Found cached Qualcomm SDK at ${_qnn_default_sdk_dir}") + set(QNN_SDK_ROOT + ${_qnn_default_sdk_dir} + CACHE PATH "Qualcomm SDK root directory" FORCE + ) + else() + message(STATUS "Downloading Qualcomm SDK") + execute_process( + COMMAND + ${PYTHON_EXECUTABLE} + ${EXECUTORCH_SOURCE_DIR}/backends/qualcomm/scripts/download_qnn_sdk.py + --dst-folder ${_qnn_default_sdk_dir} --print-sdk-path + WORKING_DIRECTORY ${EXECUTORCH_SOURCE_DIR} + RESULT_VARIABLE _qnn_sdk_download_result + OUTPUT_VARIABLE _qnn_sdk_download_output + ERROR_VARIABLE _qnn_sdk_download_error + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(NOT _qnn_sdk_download_result EQUAL 0 OR _qnn_sdk_download_output + STREQUAL "" + ) + message( + FATAL_ERROR + "Failed to download Qualcomm SDK. stdout: ${_qnn_sdk_download_output}\n" + "stderr: ${_qnn_sdk_download_error}" + ) + endif() + set(QNN_SDK_ROOT + ${_qnn_sdk_download_output} + CACHE PATH "Qualcomm SDK root directory" FORCE + ) + endif() + set(ENV{QNN_SDK_ROOT} ${QNN_SDK_ROOT}) +endif() + if(NOT DEFINED QNN_SDK_ROOT) message( FATAL_ERROR @@ -214,7 +255,9 @@ add_subdirectory( install( TARGETS qnn_executorch_backend EXPORT ExecuTorchTargets - DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm ) # QNN pybind @@ -275,4 +318,12 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") ${QNN_EXECUTORCH_ROOT_DIR}/aot/python ${CMAKE_CURRENT_BINARY_DIR}/qnn_executorch/python ) + + install( + TARGETS PyQnnManagerAdaptor PyQnnWrapperAdaptor + LIBRARY + DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python + RUNTIME + DESTINATION ${CMAKE_INSTALL_LIBDIR}/executorch/backends/qualcomm/python + ) endif() diff --git a/backends/qualcomm/README.md b/backends/qualcomm/README.md index 89c7cf07b25..fa82c38b2fd 100644 --- a/backends/qualcomm/README.md +++ b/backends/qualcomm/README.md @@ -22,6 +22,7 @@ Please check `generate_qnn_executorch_compiler_spec()` in - Snapdragon 8 Gen 3 - Snapdragon 8 Elite - SA8295 +- SA8255 - SSG2115P - SSG2125P - SXR1230P diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index a96c5b21d42..691ba1607ff 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -43,9 +43,12 @@ class LayoutTransform(ExportPass): layout_sensitive_ops = { exir_ops.edge.aten.adaptive_avg_pool2d.default, exir_ops.edge.aten._adaptive_avg_pool3d.default, + exir_ops.edge.aten.adaptive_max_pool2d.default, exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.avg_pool3d.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.grid_sampler_2d.default, + exir_ops.edge.aten.grid_sampler_3d.default, exir_ops.edge.aten.instance_norm.default, exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 54cfae6591c..2f1c2d54828 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -2,15 +2,19 @@ Thank you for contributing to Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly get the essentials of implementing operator builder to unblock yourself and land pull requests more efficiently. ## Sections -* [References](#references) -* [Getting Started](#getting-started) - * [Identify Unsupported Operator](#identify-unsupported-operator) - * [Check Operator Spec](#check-operator-spec) - * [Implementation](#implementation) - * [Quantizer Annotation](#quantizer-annotation) -* [Operator Support Status](#operator-support-status) -* [Issues](#issues) -* [Pull Requests](#pull-requests) +- [Contribution for More Operators](#contribution-for-more-operators) + - [Sections](#sections) + - [References](#references) + - [Qualcomm AI Engine Direct](#qualcomm-ai-engine-direct) + - [PyTorch](#pytorch) + - [Getting Started](#getting-started) + - [Identify Unsupported Operator](#identify-unsupported-operator) + - [Check Operator Spec](#check-operator-spec) + - [Implementation](#implementation) + - [Quantizer Annotation](#quantizer-annotation) + - [Operator Support Status](#operator-support-status) + - [Issues](#issues) + - [Pull Requests](#pull-requests) ## References ### Qualcomm AI Engine Direct @@ -365,7 +369,7 @@ Please help update following table if you are contributing new operators: + 🚫 = Deprecated, supported with other QNN Ops -| Operators | HTP - 92/116 Enabled | +| Operators | HTP - 94/116 Enabled | |-----------|---------| | Argmax | ✓ | | Argmin | ✓ | @@ -431,7 +435,7 @@ Please help update following table if you are contributing new operators: | Gelu | ✓ | | GetSparseIndices | ✗ | | GetSparseValues | ✗ | -| GridSample | ✗ | +| GridSample | ✓ | | GroupNorm | ✓ | | HardSwish | ✓ | | InstanceNorm | ✓ | diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 4bf0ea7e210..e982985477d 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -8,6 +8,7 @@ node_visitor, op_abs, op_adaptive_avg_pool2d, + op_adaptive_max_pool2d, op_add, op_amax, op_amin, @@ -44,6 +45,7 @@ op_gather, op_ge, op_gelu, + op_grid_sampler_2d, op_group_norm, op_gt, op_hardsigmoid, @@ -114,6 +116,7 @@ node_visitor, op_abs, op_adaptive_avg_pool2d, + op_adaptive_max_pool2d, op_add, op_amax, op_amin, @@ -150,6 +153,7 @@ op_gather, op_ge, op_gelu, + op_grid_sampler_2d, op_group_norm, op_gt, op_hardswish, diff --git a/backends/qualcomm/builders/op_adaptive_max_pool2d.py b/backends/qualcomm/builders/op_adaptive_max_pool2d.py new file mode 100644 index 00000000000..0db8f42ceb2 --- /dev/null +++ b/backends/qualcomm/builders/op_adaptive_max_pool2d.py @@ -0,0 +1,151 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import numpy as np + +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpPoolMax2d, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class AdaptiveMaxPool2D(NodeVisitor): + target = ["aten.adaptive_max_pool2d.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + users = list(node.users.keys()) + for user in users: + if user.target.__name__ == "getitem": + getitem_index = user.args[1] + if getitem_index != 0: + warnings.warn( + f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}", + stacklevel=1, + ) + return + + if len(node.args) > 2: + warnings.warn( + "[QNN Delegate Op Builder]: The return_indices is not supported, fallback op", + stacklevel=1, + ) + return + + input_height = input_tensor.shape[1] + input_width = input_tensor.shape[2] + # output cases + out_wh = cast(List[int], node.args[1]) + if len(out_wh) == 1: + output_height = node.args[1][0] + output_width = node.args[1][0] + else: + output_height = node.args[1][0] + output_width = node.args[1][1] + if output_height is None: + output_height = input_height + if output_width is None: + output_width = input_width + # NOTE: Here we need not to emphasize on mode, cuz the output shape is decided by user. + mode = OpPoolMax2d.RoundingMode.FLOOR + + # floor division + stride_height = input_height // output_height + filter_height = input_height - (output_height - 1) * stride_height + stride_width = input_width // output_width + filter_width = input_width - (output_width - 1) * stride_width + + filter = [filter_height, filter_width] + filter_shape = [len(filter)] + + stride = [stride_height, stride_width] + stride_shape = [len(stride)] + + padding = [0, 0] + padding_shape = [len(padding), len(padding)] + + out_tensor = self.get_tensor(node, node, 0) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + adaptive_max_pool2d_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpPoolMax2d.op_name, + ) + + adaptive_max_pool2d_op.AddInputTensors([input_tensor_wrapper]) + adaptive_max_pool2d_op.AddOutputTensors([output_tensor_wrapper]) + + adaptive_max_pool2d_op.AddTensorParam( + OpPoolMax2d.param_filter_size, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(filter_shape), + filter_shape, + np.array( + filter, + dtype=np.uint32, + ), + True, + ) + + adaptive_max_pool2d_op.AddTensorParam( + OpPoolMax2d.param_stride, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(stride_shape), + stride_shape, + np.array( + stride, + dtype=np.uint32, + ), + True, + ) + + adaptive_max_pool2d_op.AddTensorParam( + OpPoolMax2d.param_pad_amount, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(padding_shape), + padding_shape, + np.array( + [[padding[0], padding[0]], [padding[1], padding[1]]], + dtype=np.uint32, + ), + True, + ) + + adaptive_max_pool2d_op.AddScalarParam( + OpPoolMax2d.param_rounding_mode, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(mode)}, + ) + + return adaptive_max_pool2d_op diff --git a/backends/qualcomm/builders/op_grid_sampler_2d.py b/backends/qualcomm/builders/op_grid_sampler_2d.py new file mode 100644 index 00000000000..6b6e7bf8610 --- /dev/null +++ b/backends/qualcomm/builders/op_grid_sampler_2d.py @@ -0,0 +1,162 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import numpy as np + +import torch + +from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_DTYPE + +from .node_visitor import NodeVisitor, QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpGridSample, OpTranspose, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class GridSample(NodeVisitor): + target = ["aten.grid_sampler_2d.default", "aten.grid_sampler_3d.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + grid_sample_op_list = [] + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + grid_node = self.get_node(node.args[1]) + grid_tensor = self.get_tensor(grid_node, node) + grid_tensor_wrapper = self.define_tensor( + grid_node, + node, + grid_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + input_shape = input_node.meta["val"].shape + input_rank = len(input_shape) + if input_rank not in [4, 5]: + warnings.warn( + "[QNN Delegate Op Builder]: The shape is not supported, fallback op", + stacklevel=1, + ) + return + + # About this operator, in ATen, the layout of input_tensor and of grid_tensor are not identical. + # But in HW they are all NHWC or NDHWC. So, we make shape transformation again. + if input_rank == 4: + dims_shape_back = (0, 3, 1, 2) + elif input_rank == 5: + dims_shape_back = (0, 4, 1, 2, 3) + else: + warnings.warn( + f"[QNN Delegate Op Builder]: Not support rank {input_rank}, fallback op", + stacklevel=1, + ) + return + + grid_quant_encoding, grid_quant_configs = self.get_quant_encoding_conf( + grid_node, node + ) + grid_dtype = ( + QNN_TENSOR_TYPE_MAP[grid_tensor.dtype] + if grid_quant_encoding + == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED + else QNN_QUANT_TYPE_MAP[ + ( + torch.uint16 + if grid_quant_configs[QCOM_DTYPE] == torch.int32 + else grid_quant_configs[QCOM_DTYPE] + ) + ] + ) + # transpose + permute_output_tensor = grid_tensor.permute(dims=dims_shape_back) + transpose_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_transpose", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=grid_dtype, + quant_encoding=grid_quant_encoding, + quant_configs=grid_quant_configs, + dims=permute_output_tensor.size(), + tensor=permute_output_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + + permute_order = cast(List[int], dims_shape_back) + permute_order_shape = [len(permute_order)] + transpose_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTranspose.op_name, + ) + transpose_op.AddInputTensors([grid_tensor_wrapper]) + transpose_op.AddOutputTensors([transpose_output_tensor_wrapper]) + transpose_op.AddTensorParam( + OpTranspose.param_perm, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(permute_order_shape), + permute_order_shape, + np.array(permute_order, dtype=np.uint32), + True, + ) + grid_sample_op_list.append(transpose_op) + + out_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + out_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + align_corners = node.args[4] if len(node.args) > 4 else False + padding_mode = node.args[3] if len(node.args) > 3 else 0 + interpo_mode = node.args[2] if len(node.args) > 2 else 0 + + grid_sample_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpGridSample.op_name, + ) + grid_sample_op.AddInputTensors( + [input_tensor_wrapper, transpose_output_tensor_wrapper] + ) + grid_sample_op.AddOutputTensors([output_tensor_wrapper]) + grid_sample_op.AddScalarParam( + OpGridSample.param_align_corners, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: align_corners}, + ) + grid_sample_op.AddScalarParam( + OpGridSample.param_mode, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(interpo_mode)}, + ) + grid_sample_op.AddScalarParam( + OpGridSample.param_padding_mode, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: np.uint32(padding_mode)}, + ) + grid_sample_op_list.append(grid_sample_op) + return grid_sample_op_list diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 19c63015f64..ecc221885dc 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -304,6 +304,24 @@ class OpGather: param_axis: str = "axis" +class OpGridSample: + op_name: str = "GridSample" + param_align_corners: str = "align_corners" + param_mode: str = "mode" + param_padding_mode: str = "padding_mode" + + @unique + class Mode(IntEnum): + BILINAR = 0 + NEAREST = 1 + + @unique + class PaddingMode(IntEnum): + ZEROS = 0 + BORDER = 1 + REFLECTION = 2 + + @dataclass(init=False, frozen=True) class OpGatherElements: op_name: str = "GatherElements" diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 76f22552c8d..2447e6a06c6 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -19,7 +19,6 @@ ] to_be_implemented_operator = [ - exir_ops.edge.aten.adaptive_max_pool2d.default, exir_ops.edge.aten.adaptive_max_pool3d.default, exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.log10.default, diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 97e0b4bd109..7c45845f516 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -16,11 +16,19 @@ def generate_qnn_executorch_option( compiler_specs: List[CompileSpec], ) -> bytes: + qnn_compile_spec_buffer = None + for compiler_spec in compiler_specs: if compiler_spec.key == QCOM_QNN_COMPILE_SPEC: qnn_compile_spec_buffer = compiler_spec.value else: raise ValueError(f"unknown compiler spec key value: {compiler_spec.key}") + + if qnn_compile_spec_buffer is None: + raise ValueError( + f"QNN compile spec (key={QCOM_QNN_COMPILE_SPEC}) not found in compiler_specs" + ) + return qnn_compile_spec_buffer diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 8b59de3bd4e..7df29d431ea 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -520,6 +520,29 @@ def annotate_full(node: Node, quantization_config: QuantizationConfig) -> None: ) +@register_annotator([torch.ops.aten.grid_sampler.default]) +def annotate_grid_sampler(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + input_act_qsec = quantization_config.input_activation + output_act_qsec = quantization_config.output_activation + + input_qspec_map = {} + input_act0 = node.args[0] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qsec + + input_act1 = node.args[1] + if isinstance(input_act1, Node): + input_qspec_map[input_act1] = input_act_qsec + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qsec, + _annotated=True, + ) + + @register_annotator( [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default] ) @@ -561,6 +584,27 @@ def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.adaptive_max_pool2d.default]) +def annotate_adaptive_max_pool2d( + node: Node, quantization_config: QuantizationConfig +) -> None: + if _is_annotated([node]): + return + input_act_qsec = quantization_config.input_activation + output_act_qsec = quantization_config.output_activation + + input_qspec_map = {} + input_act0 = node.args[0] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qsec + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qsec, + _annotated=True, + ) + + @register_annotator( [ torch.ops.aten.adaptive_avg_pool1d.default, diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index c592ad64da6..e34630538d0 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum, unique + from typing import Sequence import torch @@ -17,7 +17,6 @@ get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, get_ptq_per_channel_quant_config, - get_qat_per_channel_quant_config, QuantizationConfig, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -32,36 +31,6 @@ ) -def annotate_down_proj( - gm: torch.fx.GraphModule, quantization_config: QuantizationConfig -): - for node in gm.graph.nodes: - if ( - node.target == torch.ops.aten.conv2d.default - and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"]) - and node.args[0].target == torch.ops.aten.mul.Tensor - ): - input_qspec_map = {} - input_qspec_map[node.args[0]] = quantization_config.input_activation - input_qspec_map[node.args[1]] = quantization_config.weight - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - -@unique -class StaticLLMQuantConfig(Enum): - """ - Layer namespace configuration for Qualcomm's static LLaMA quantization. - """ - - wq_sha = "wq_sha" # Query weight (single head) - wk_sha = "wk_sha" # Key weight (single head) - wv_sha = "wv_sha" # Value weight (single head) - - def annotate_eurobert(gm: torch.fx.GraphModule): """ QNN does not support int32 -> signed 16bit quant @@ -123,49 +92,6 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule): break -def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None: - """ - This function is for static LLM models. - This function will annotate the last conv(linear), which is the lm_head, as 16a8w. - """ - - def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: - input_qspec_map = {} - input_act = node.args[0] - input_spec = quantization_config.input_activation - input_qspec_map[input_act] = input_spec - - weight = node.args[1] - input_qspec_map[weight] = quantization_config.weight - - if len(node.args) > 2 and isinstance(node.args[2], Node): - input_qspec_map[node.args[2]] = quantization_config.bias(node) - - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - if is_qat: - quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - else: - quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - for node in gm.graph.nodes: - if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: - if "nn_module_stack" in node.meta: - module_values_list = list(node.meta["nn_module_stack"].values()) - full_qualified_name = module_values_list[-1][0] - if full_qualified_name == "output.conv": - annotate_conv2d( - node, quantization_config=quantization_config_16a8w_per_channel - ) - - def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): for node in gm.graph.nodes: if node.op == "output": @@ -200,48 +126,6 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): ) -def annotate_qkv_proj_sha( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - qkv_tags: set[StaticLLMQuantConfig], -): - """ - Annotates QKV projection layers in a GraphModule for quantization, - specifically layers defined in StaticLLMQuantConfig. - - Args: - qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers - (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in - StaticLLMQuantConfig are allowed. - - Raises: - ValueError: If any tag in `qkv_tags` is not among the allowed enum members. - """ - - # Get all valid tags from the StaticLLMQuantConfig enum - allowed_tags = set(StaticLLMQuantConfig) - invalid_tags = qkv_tags - allowed_tags - if invalid_tags: - raise ValueError( - f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}" - ) - - for node in gm.graph.nodes: - if node.target == torch.ops.aten.conv2d.default and any( - tag.value in node.meta["stack_trace"] for tag in qkv_tags - ): - input_qspec_map = {} - input_qspec_map[node.args[0]] = quantization_config.input_activation - input_qspec_map[node.args[1]] = quantization_config.weight - if len(node.args) > 2 and isinstance(node.args[2], Node): - input_qspec_map[node.args[2]] = quantization_config.bias(node) - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - def annotate_kv_8bit( # noqa: C901 gm: torch.fx.GraphModule, is_qat=False, @@ -262,7 +146,6 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): input_act = node.args[0] input_spec = quantization_config.input_activation input_qspec_map[input_act] = input_spec - input_act1 = node.args[1] input_spec1 = quantization_config.weight input_qspec_map[input_act1] = input_spec1 diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index e22d5b30fa7..593eb77961a 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -136,6 +136,61 @@ def get_8a8w_qnn_ptq_config( return quantization_config +def get_8a4w_qnn_ptq_config( + act_symmetric: bool = True, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + else: + # PyTorch will remove redundant observers based on attributes such as: + # dtype, quant_min, quant_max, ch_axis, etc. + # Providing values like quant_min and quant_max can help observers compare + # and further reduce the number of observers. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + # 4 bits quantization only supports specific ops. def get_16a4w_qnn_ptq_config( act_observer=MovingAverageMinMaxObserver, diff --git a/backends/qualcomm/quantizer/quant_recipe.py b/backends/qualcomm/quantizer/quant_recipe.py new file mode 100644 index 00000000000..92b9757e1fb --- /dev/null +++ b/backends/qualcomm/quantizer/quant_recipe.py @@ -0,0 +1,402 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import re +from abc import ABC, abstractmethod +from enum import IntEnum, unique +from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple + +import torch +from executorch.backends.qualcomm.quantizer.quantizer import ( + ModuleQConfig, + QnnQuantizer, + QuantDtype, + QuantizationConfig, +) +from tabulate import tabulate +from torch._ops import OpOverload +from torchao.quantization.pt2e import UniformQuantizationObserverBase + +from .annotators import OP_ANNOTATOR + + +def extract_node_metadata_mapping(node: torch.fx.Node): + deepest_module = None + + if node.op == "call_function" and "nn_module_stack" in node.meta: + deepest_module = list(node.meta["nn_module_stack"].values())[-1][0] + + return deepest_module + + +@unique +class QuantGranularity(IntEnum): + """ + Defines the quantization granularity levels: + - PER_TENSOR: single scale offset for entire tensor. + - PER_CHANNEL: independent scale/offset per channel within tensor. + - PER_BLOCK: independent scale/offset per block within tensor. + """ + + PER_TENSOR = 0 + PER_CHANNEL = 1 + PER_BLOCK = 2 + + +class QuantizationStrategy(ABC): + """ + Abstract base class for strategies that assign quantization config to FX graph nodes. + + Each strategy defines how to match nodes (e.g., by operator target, module stack pattern) + and provides a corresponding quantization config when a match occurs. + + Attributes: + quant_dtype (QuantDtype): Data type for quantization (e.g., 16a8w, 16a4w). + is_qat (bool): Whether the strategy applies QAT (True) or PTQ (False). + granularity (QuantGranularity): Quantization granularity (PER_TENSOR, PER_CHANNEL, PER_BLOCK). + act_observer (UniformQuantizationObserverBase): Observer class for activation quantization. + extra_kwargs (Dict): Additional configuration parameters (e.g., block size). + note (str): Developer notes or comments. + priority (int): Priority for resolving conflicts among multiple strategies. + + Abstract Methods: + _matches(node): Return True if the node matches this strategy's criteria. + """ + + def __init__( + self, + quant_dtype: QuantDtype, + is_qat: bool, + granularity: QuantGranularity, + act_observer: UniformQuantizationObserverBase, + extra_kwargs: Dict, + note: str, + priority: int, + ): + self.quant_dtype = quant_dtype + self.is_qat = is_qat + self.granularity = granularity + self.act_observer = act_observer + self.extra_kwargs = extra_kwargs + self.note = note + self.priority = priority + + self.quant_config = ModuleQConfig( + quant_dtype=self.quant_dtype, + is_qat=self.is_qat, + is_conv_per_channel=True, + is_linear_per_channel=True, + act_observer=self.act_observer, + ) + + @abstractmethod + def _matches(self, node: torch.fx.Node) -> bool: + pass + + def get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]: + op: OpOverload = node.target + + if not self._matches(node): + return None + + if self.granularity == QuantGranularity.PER_TENSOR: + return self.quant_config.quant_config + elif self.granularity == QuantGranularity.PER_CHANNEL: + ch_axis = self.quant_config.use_per_channel_weight_quant_ops.get(op) + assert ( + ch_axis is not None + and len(self.quant_config.per_channel_quant_config_list) > ch_axis + ), f"Unsupported per channel quantization axis: {ch_axis}, please increase the range of per_channel_quant_config_list" + return self.quant_config.per_channel_quant_config_list[ch_axis] + elif self.granularity == QuantGranularity.PER_BLOCK: + ch_axis = self.quant_config.op_axis_dict.get(op) + assert ( + ch_axis is not None + and len(self.quant_config.per_block_quant_config_list) > ch_axis + ), f"Unsupported per block quantization axis: {ch_axis}, please increase the range of per_block_quant_config_list" + config = self.quant_config.per_block_quant_config_list[ch_axis] + config.block_size = self.extra_kwargs["block_size"] + return config + else: + raise ValueError( + f"Unsupported quantization granularity: {self.granularity}. " + f"Supported values: {[granularity.name for granularity in QuantGranularity]}" + ) + + +class ByNodeTarget(QuantizationStrategy): + """ + Strategy that assigns quantization config to nodes based on their op target. + Useful for applying quantization to specific operations such as `aten.conv2d` or `aten.linear`. + + Attributes: + targets (Set[OpOverload]): Set of op overloads to match against node targets. + """ + + def __init__( + self, + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + targets: Set[OpOverload], + ): + super().__init__( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + ) + self.targets = targets + + def _matches(self, node: torch.fx.Node) -> bool: + # Matching: A node matches if its `node.target` is in the `targets` set. + return node.target in self.targets + + +class ByNameRegex(QuantizationStrategy): + """ + Strategy that assigns quantization config to nodes whose module stack matches given regex patterns. + Useful for targeting layers by name patterns (e.g., "layers.[0-3].feed_forward" or "layers.*.attention") in the module hierarchy. + + Attributes: + patterns (Set[str]): Set of regex patterns to match against module stack paths. + """ + + def __init__( + self, + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + patterns: Set[str], + ): + super().__init__( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs, + note, + priority, + ) + self.patterns = patterns + + def _matches(self, node: torch.fx.Node) -> bool: + # Matching: A node matches if its `nn_module_stack` metadata contains a module path that matches any regex pattern. + if node.op == "call_function" and "nn_module_stack" in node.meta: + for module_stack, _ in list(node.meta["nn_module_stack"].values())[::-1]: + if module_stack and any( + re.search(p, module_stack) for p in self.patterns + ): + return True + return False + + +class QuantRecipe: + """ + A QuantRecipe builder for defining quantization strategies to an FX GraphModule. + + QuantRecipe manages a collection of strategies (e.g., by operator target or regex pattern) + and applies them to nodes in an FX graph to produce fine-grained quantization annotations. + + Attributes: + verbose (bool): If True, prints a summary after annotation. + custom_quant_annotations (Sequence[Callable]): Custom annotation functions applied after strategies. + + _strategies (List[QuantizationStrategy]): Registered quantization strategies. + _pending_annotate_nodes (Dict[torch.fx.Node, Tuple[QuantizationConfig, QuantizationStrategy]]): + Internal mapping of nodes to their resolved quantization config and strategy. + """ + + def __init__( + self, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + extra_kwargs: Optional[dict] = None, + verbose: bool = False, + ): + """ + Initialize a QuantRecipe with a default quantization strategy. + + Args: + quant_dtype (QuantDtype): Data type for quantization (e.g., int8, int4). + is_qat (bool): Whether to apply QAT (True) or PTQ (False). + act_observer (UniformQuantizationObserverBase): Observer class for activation quantization. + granularity (QuantGranularity): Quantization granularity (PER_TENSOR, PER_CHANNEL, PER_BLOCK). + note (str): Optional description for the default strategy. + extra_kwargs (dict, optional): Additional parameters (e.g., block size, group size). + verbose (bool): If True, prints a summary table after annotation. + """ + + self.verbose = verbose + self.custom_quant_annotations: Sequence[Callable] = [] + + self._strategies: List[QuantizationStrategy] = [] + self._pending_annotate_nodes: Dict[ + torch.fx.Node, Tuple[QuantizationConfig, QuantizationStrategy] + ] = {} + self._default_strategy = ByNodeTarget( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority=1, + targets=QnnQuantizer.SUPPORTED_OPS, + ) + + def _annotate_custom_annotation(self, gm: torch.fx.GraphModule) -> None: + for annotation_func in self.custom_quant_annotations: + annotation_func(gm) + + def annotate(self, graph_module: torch.fx.GraphModule): + # Sort node level strategies by (priority, insertion index). + # Higher priority value comes first; if priorities are equal, original insertion order is preserved. + strategies: List[QuantizationStrategy] = [ + strategy + for _, strategy in sorted( + enumerate(self._strategies), + key=lambda x: (x[1].priority, x[0]), + reverse=True, + ) + ] + # Ensure the default strategy is appended last + strategies.append(self._default_strategy) + + for node in graph_module.graph.nodes: + for strategy in strategies: + if isinstance(node.target, str) or node in self._pending_annotate_nodes: + continue + + if quant_config := strategy.get_quant_config(node): + self._pending_annotate_nodes[node] = (quant_config, strategy) + + if self.verbose: + print(self.summary()) + + for node in graph_module.graph.nodes: + if isinstance(node.target, str): + continue + if node not in self._pending_annotate_nodes: + print(f"No quant config is implemented for op, {node.target}") + continue + + OP_ANNOTATOR[node.target](node, self._pending_annotate_nodes[node][0]) + + # custom annotation + self._annotate_custom_annotation(graph_module) + + def add_node_target( + self, + targets, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + priority: int = 1, + extra_kwargs: Optional[dict] = None, + ): + self._strategies.append( + ByNodeTarget( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority, + targets, + ), + ) + return self + + def add_regex( + self, + regex, + quant_dtype, + is_qat, + act_observer: UniformQuantizationObserverBase, + granularity: QuantGranularity, + note: str = "", + priority: int = 1, + extra_kwargs: Optional[dict] = None, + ): + """ + Add a quantization strategy targeting nodes whose module stack matches given regex patterns. + + Args: + regex (Iterable[str]): Regex patterns to match module stack paths. + quant_dtype (QuantDtype): Data type for quantization. + is_qat (bool): Whether to apply QAT or PTQ. + act_observer (UniformQuantizationObserverBase): Observer for activation quantization. + granularity (QuantGranularity): Tensor/channel/block granularity. + note (str): Optional description for the strategy. + priority (int): Strategy priority (higher value = higher precedence). + extra_kwargs (dict, optional): Additional parameters for the strategy. + """ + self._strategies.append( + ByNameRegex( + quant_dtype, + is_qat, + granularity, + act_observer, + extra_kwargs or {}, + note, + priority, + regex, + ), + ) + return self + + def summary(self, max_rows: int = -1): + if not self._pending_annotate_nodes: + return None + + headers = [ + "module_stack", + "op_target", + "quantize", + "act_observer", + "granularity", + "note", + "extra_kwargs", + ] + rows = [] + for i, (node, (_, strategy)) in enumerate(self._pending_annotate_nodes.items()): + if max_rows > 0 and i >= max_rows: + break + + row = [ + extract_node_metadata_mapping(node), + node.target, + f"{strategy.quant_dtype.name}/{'QAT' if strategy.is_qat else 'PTQ'}", + strategy.act_observer.__name__, + strategy.granularity.name, + strategy.note, + strategy.extra_kwargs, + ] + rows.append(row) + + if max_rows > 0 and len(self._pending_annotate_nodes) > max_rows: + rows.append(["..."] * len(headers)) + + return tabulate(rows, headers=headers, tablefmt="grid") diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 4d0f1098a62..9ca9a7dad6c 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -24,6 +24,7 @@ get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, get_16a8w_qnn_qat_config, + get_8a4w_qnn_ptq_config, get_8a8w_qnn_ptq_config, get_8a8w_qnn_qat_config, get_ptq_per_block_quant_config, @@ -44,6 +45,7 @@ "get_16a16w_qnn_ptq_config", "get_8a8w_qnn_ptq_config", "get_8a8w_qnn_qat_config", + "get_8a4w_qnn_ptq_config", "get_16a4w_qnn_qat_config", "get_ptq_per_block_quant_config", ] @@ -60,6 +62,7 @@ class QuantDtype(IntEnum): use_16a4w = 2 use_16a4w_block = 3 use_8a8w = 4 + use_8a4w = 5 QUANT_CONFIG_DICT = { @@ -109,6 +112,15 @@ class QuantDtype(IntEnum): partial(get_ptq_per_channel_quant_config), None, ), + (QuantDtype.use_8a4w, False): ( + get_8a4w_qnn_ptq_config, + partial( + get_ptq_per_channel_quant_config, + act_dtype=torch.uint8, + weight_dtype=torch.int4, + ), + None, + ), # QAT, (QuantDtype.use_16a4w, True): ( get_16a4w_qnn_qat_config, @@ -242,10 +254,12 @@ def __init__(self): self.submodule_qconfig_list: List[ Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig] ] = [] + self.block_size_map = {} self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() + self.recipe = None def _annotate(self, gm: GraphModule) -> None: """ @@ -348,14 +362,20 @@ def annotate(self, model: GraphModule) -> GraphModule: """ Annotates GraphModule during prepare_pt2e. + If a recipe is provided, it will be used to annotate the model. + Otherwise, fallback to the default annotation flow. + Args: model (GraphModule): The FX GraphModule to annotate. Returns: GraphModule: The annotated model. """ - self._annotate(model) - self._annotate_custom_annotation(model) + if self.recipe: + self.recipe.annotate(model) + else: + self._annotate(model) + self._annotate_custom_annotation(model) return model @@ -389,10 +409,10 @@ def set_default_quant_config( """ self.default_quant_config = ModuleQConfig( quant_dtype, - is_qat, - is_conv_per_channel, - is_linear_per_channel, - act_observer, + is_qat=is_qat, + is_conv_per_channel=is_conv_per_channel, + is_linear_per_channel=is_linear_per_channel, + act_observer=act_observer, ) def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None: diff --git a/backends/qualcomm/scripts/download_qnn_sdk.py b/backends/qualcomm/scripts/download_qnn_sdk.py index 747524a0e5b..5524adf8988 100644 --- a/backends/qualcomm/scripts/download_qnn_sdk.py +++ b/backends/qualcomm/scripts/download_qnn_sdk.py @@ -1,4 +1,4 @@ -# Add these imports for additional logging +import argparse import ctypes import logging import os @@ -14,6 +14,8 @@ import zipfile from typing import Dict, List, Optional, Tuple +import requests +from requests.adapters import HTTPAdapter, Retry logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -115,12 +117,34 @@ def _atomic_download(url: str, dest: pathlib.Path): def _download_archive(url: str, archive_path: pathlib.Path) -> bool: - """Download archive from URL with progress reporting.""" + """Robust streaming download with retries.""" + logger.debug("Archive will be saved to: %s", archive_path) + session = requests.Session() + retries = Retry( + total=5, + backoff_factor=1.0, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=["GET"], + ) + session.mount("https://", HTTPAdapter(max_retries=retries)) + try: - urllib.request.urlretrieve(url, archive_path, _make_report_progress()) + with session.get(url, stream=True) as r: + r.raise_for_status() + + downloaded = 0 + chunk_size = 1024 * 1024 # 1MB + + with open(archive_path, "wb") as f: + for chunk in r.iter_content(chunk_size): + if chunk: + f.write(chunk) + downloaded += len(chunk) + logger.info("Download completed!") + except Exception as e: logger.exception("Error during download: %s", e) return False @@ -131,27 +155,8 @@ def _download_archive(url: str, archive_path: pathlib.Path) -> bool: elif not archive_path.exists(): logger.error("File was not downloaded!") return False - return True - - -def _make_report_progress(): - """Return a callback to report download progress.""" - last_reported = 0 - - def report_progress(block_num, block_size, total_size): - nonlocal last_reported - try: - downloaded = block_num * block_size - percent = downloaded / total_size * 100 if total_size else 100.0 - except Exception: - percent, downloaded, total_size = 0.0, block_num * block_size, 0 - if percent - last_reported >= 20 or percent >= 100: - logger.info( - "Downloaded: %d/%d bytes (%.2f%%)", downloaded, total_size, percent - ) - last_reported = percent - return report_progress + return True def _extract_archive( @@ -592,3 +597,46 @@ def install_qnn_sdk() -> bool: # libc++ and QNN SDK setup return _ensure_libcxx_stack() and _ensure_qnn_sdk_lib() + + +def main(argv: Optional[List[str]] = None) -> int: + parser = argparse.ArgumentParser( + description="Helper utility for Qualcomm SDK staging." + ) + parser.add_argument( + "--dst-folder", + type=pathlib.Path, + default=SDK_DIR, + help="Destination directory for the Qualcomm SDK.", + ) + parser.add_argument( + "--print-sdk-path", + action="store_true", + help="Print the resolved Qualcomm SDK path to stdout.", + ) + parser.add_argument( + "--install-sdk", + action="store_true", + help="Ensure the SDK and runtime libraries are staged and loaded.", + ) + args = parser.parse_args(argv) + + logging.basicConfig(level=logging.INFO) + + sdk_path: Optional[pathlib.Path] + if args.install_sdk: + if not install_qnn_sdk(): + return 1 + sdk_path = pathlib.Path(os.environ.get("QNN_SDK_ROOT", args.dst_folder)) + else: + sdk_path = _download_qnn_sdk(dst_folder=args.dst_folder) + if sdk_path is None: + return 1 + + if args.print_sdk_path and sdk_path is not None: + print(sdk_path) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/backends/qualcomm/scripts/install_qnn_sdk.sh b/backends/qualcomm/scripts/install_qnn_sdk.sh index d9d79d3d0db..5bc0f7eeb1d 100644 --- a/backends/qualcomm/scripts/install_qnn_sdk.sh +++ b/backends/qualcomm/scripts/install_qnn_sdk.sh @@ -27,7 +27,7 @@ setup_android_ndk() { mkdir -p "${NDK_INSTALL_DIR}" NDK_ZIP="android-ndk-${NDK_VERSION}-linux.zip" - curl --retry 3 -Lo "/tmp/${NDK_ZIP}" "https://dl.google.com/android/repository/${NDK_ZIP}" + curl --retry 3 --retry-delay 5 --retry-connrefused --continue-at - -Lo "/tmp/${NDK_ZIP}" "https://dl.google.com/android/repository/${NDK_ZIP}" unzip -q "/tmp/${NDK_ZIP}" -d "${NDK_INSTALL_DIR}" mv "${NDK_INSTALL_DIR}/android-ndk-${NDK_VERSION}" "${NDK_INSTALL_DIR}/ndk" diff --git a/backends/qualcomm/serialization/qc_compiler_spec.fbs b/backends/qualcomm/serialization/qc_compiler_spec.fbs index 3000c9e1187..85affe3464d 100644 --- a/backends/qualcomm/serialization/qc_compiler_spec.fbs +++ b/backends/qualcomm/serialization/qc_compiler_spec.fbs @@ -34,6 +34,7 @@ table HtpInfo { enum QcomChipset: int { UNKNOWN_SM = 0, SA8295 = 39, + SM8350 = 35, SM8450 = 36, SM8475 = 42, SM8550 = 43, @@ -46,6 +47,7 @@ enum QcomChipset: int { SXR2330P = 75, QCS9100 = 77, SAR2230P = 95, + SA8255 = 52, } /// Indicate the information of the specified SoC. diff --git a/backends/qualcomm/serialization/qc_schema.py b/backends/qualcomm/serialization/qc_schema.py index bcbd53a235e..c188c555c41 100644 --- a/backends/qualcomm/serialization/qc_schema.py +++ b/backends/qualcomm/serialization/qc_schema.py @@ -40,6 +40,7 @@ class HtpInfo: class QcomChipset(IntEnum): UNKNOWN_SM = 0 SA8295 = 39 # v68 + SM8350 = 35 # v68 SM8450 = 36 # v69 SM8475 = 42 # v69 SM8550 = 43 # v73 @@ -52,6 +53,7 @@ class QcomChipset(IntEnum): SXR2330P = 75 # v79 QCS9100 = 77 # v73 SAR2230P = 95 # v81 + SA8255 = 52 # v73 @dataclass @@ -62,9 +64,11 @@ class SocInfo: _soc_info_table = { QcomChipset.SA8295: SocInfo(QcomChipset.SA8295, HtpInfo(HtpArch.V68, 8)), + QcomChipset.SM8350: SocInfo(QcomChipset.SM8350, HtpInfo(HtpArch.V68, 4)), QcomChipset.SM8450: SocInfo(QcomChipset.SM8450, HtpInfo(HtpArch.V69, 8)), QcomChipset.SM8475: SocInfo(QcomChipset.SM8475, HtpInfo(HtpArch.V69, 8)), QcomChipset.SM8550: SocInfo(QcomChipset.SM8550, HtpInfo(HtpArch.V73, 8)), + QcomChipset.SA8255: SocInfo(QcomChipset.SA8255, HtpInfo(HtpArch.V73, 8)), QcomChipset.SM8650: SocInfo(QcomChipset.SM8650, HtpInfo(HtpArch.V75, 8)), QcomChipset.SM8750: SocInfo(QcomChipset.SM8750, HtpInfo(HtpArch.V79, 8)), QcomChipset.SSG2115P: SocInfo(QcomChipset.SSG2115P, HtpInfo(HtpArch.V73, 2)), diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 1674c99175a..cdd0c194fe3 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -41,6 +41,19 @@ def forward(self, x): return torch.abs(x) +class AdaptiveMaxPool2D(torch.nn.Module): + def __init__(self, output_size, return_indices=False): + super().__init__() + self.output_size = output_size + self.return_indices = return_indices + + def forward(self, x): + adaptive_max_pool = torch.nn.AdaptiveMaxPool2d( + self.output_size, self.return_indices + ) + return adaptive_max_pool(x) + + class AdaptiveAvgPool1D(torch.nn.Module): def __init__(self): super().__init__() @@ -1098,6 +1111,20 @@ def forward(self, x): return x > self.constant +class GridSample(torch.nn.Module): + def __init__(self, mode, padding_mode, align_corners): + super().__init__() + self.mode = mode + self.align_corners = align_corners + self.padding_mode = padding_mode + + def forward(self, x, grid): + grid_sample = torch.nn.functional.grid_sample( + x, grid, self.mode, self.padding_mode, self.align_corners + ) + return grid_sample + + class GroupNorm(torch.nn.Module): def __init__(self, bias=True): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 082c1ea5a08..c878edd53c9 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -10,6 +10,7 @@ import sys import tempfile import unittest +from dataclasses import dataclass from functools import partial from multiprocessing.connection import Listener from pathlib import Path @@ -134,6 +135,21 @@ def test_qnn_backend_adaptive_avg_pool3d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_inputs[j]) + def test_qnn_backend_adaptive_max_pool2d(self): + sample_input = (torch.randn(1, 512, 24, 24),) + # NOTE: Currently, we only support the return_indices is False and default is False. + # NOTE: Currently, we only support the case mod(in_w, out_w)=0 and mod(in_h, out_h)=0. + modules = [ + AdaptiveMaxPool2D((1, 1), False), # noqa: F405 + AdaptiveMaxPool2D((4, 4)), # noqa: F405 + AdaptiveMaxPool2D((24, 24)), # noqa: F405 + AdaptiveMaxPool2D((None, 4)), # noqa: F405 + AdaptiveMaxPool2D((12, None)), # noqa: F405 + ] + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_alias(self): module = Alias() # noqa: F405 sample_input = (torch.randn(1, 10),) @@ -858,6 +874,29 @@ def test_qnn_backend_gelu(self): sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_grid_sampler(self): + # NOTE: The grid_sampler 3d version is not supported in fp16. + modes = ["bilinear", "nearest"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [False, True] + grid_samples = [ + GridSample(mode, pad, align) # noqa: F405 + for mode, pad, align in itertools.product( + modes, padding_modes, align_corners + ) + ] + sample_inputs = [ + ( + torch.randn(1, 12, 14, 14), + torch.randn(1, 3, 3, 2), + ), # for grid_sampler 2d + ] + + for j in range(len(sample_inputs)): + for i, module in enumerate(grid_samples): + with self.subTest(i=i, j=j, module=module): + self.lower_module_and_test_output(module, sample_inputs[j]) + def test_qnn_backend_glu(self): modules = [torch.nn.GLU(), torch.nn.GLU(dim=0)] sample_input = (torch.randn(2, 5, 1, 4),) @@ -2091,6 +2130,22 @@ def test_qnn_backend_adaptive_avg_pool3d(self): module = self.get_qdq_module(module, sample_inputs[j]) self.lower_module_and_test_output(module, sample_inputs[j]) + def test_qnn_backend_adaptive_max_pool2d(self): + sample_input = (torch.randn(1, 512, 24, 24),) + # NOTE: Currently, we only support the return_indices is False and default is False. + # NOTE: Currently, we only support the case mod(in_w, out_w)=0 and mod(in_h, out_h)=0. + modules = [ + AdaptiveMaxPool2D((1, 1), False), # noqa: F405 + AdaptiveMaxPool2D((4, 4)), # noqa: F405 + AdaptiveMaxPool2D((24, 24)), # noqa: F405 + AdaptiveMaxPool2D((None, 4)), # noqa: F405 + AdaptiveMaxPool2D((12, None)), # noqa: F405 + ] + for i, module in enumerate(modules): + with self.subTest(i=i): + module_one = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module_one, sample_input) + def test_qnn_backend_alias(self): module = Alias() # noqa: F405 sample_input = (torch.randn(1, 10),) @@ -2909,6 +2964,34 @@ def test_qnn_backend_gelu(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_grid_sampler(self): + modes = ["bilinear", "nearest"] + padding_modes = ["zeros", "border", "reflection"] + align_corners = [False, True] + grid_samples = [ + GridSample(mode, pad, align) # noqa: F405 + for mode, pad, align in itertools.product( + modes, padding_modes, align_corners + ) + ] + sample_inputs = [ + ( + torch.randn(1, 12, 14, 14), + torch.randn(1, 3, 3, 2), + ), # for grid_sampler 2d + ( + torch.randn(1, 15, 9, 17, 33), + torch.randn(1, 7, 8, 9, 3), + ), # for grid_sampler 3d + ] + for j in range(len(sample_inputs)): + for i, module in enumerate(grid_samples): + with self.subTest(i=i, j=j, module=module): + module = self.get_qdq_module( + module, sample_inputs[j], quant_dtype=QuantDtype.use_16a16w + ) + self.lower_module_and_test_output(module, sample_inputs[j]) + def test_qnn_backend_glu(self): modules = [torch.nn.GLU(), torch.nn.GLU(dim=0)] sample_input = (torch.randn(2, 5, 1, 4),) @@ -5762,72 +5845,74 @@ def test_qnn_backend_seq_mse(self): class TestExampleLLMScript(TestQNN): - def test_codegen2_1b(self): - if not self.required_envs(): - self.skipTest("missing required envs") - prompt = "def hello_world():" - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - prompt, - "--temperature", - "0", - "--decoder_model", - "codegen2_1b", - "--model_mode", - "kv", - "--max_seq_len", - "128", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + @dataclass(frozen=True) + class LlmSpecs: + SM8650: float + SM8750: float + ppl: float + pte_size: float - golden_start_with = "def hello_world():" - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - if not self.compile_only: - model_out = msg["result"][0] - self.assertTrue( - model_out.startswith(golden_start_with), - f"Expected Output: {golden_start_with}. Actual Output: {model_out}", - ) - if not self.enable_x86_64: - pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 1_200_000_000) # 1200MB - if not self.compile_only and not self.enable_x86_64: - self.assertGreaterEqual(msg["inference_speed"], 60) + # TODO: refactor to support different backends + def setUp(self): + self.llm_specs = { + "gemma-2b": TestExampleLLMScript.LlmSpecs( + SM8650=32, SM8750=36, ppl=35, pte_size=2_700_000_000 + ), # 2.7 GB + "gemma3-1b": TestExampleLLMScript.LlmSpecs( + SM8650=70, SM8750=100, ppl=23, pte_size=1_200_000_000 + ), # 1.2 GB + "glm-1_5b": TestExampleLLMScript.LlmSpecs( + SM8650=42, SM8750=52, ppl=21, pte_size=1_100_000_000 + ), # 1.1 GB + "phi_4_mini": TestExampleLLMScript.LlmSpecs( + SM8650=14, SM8750=19, ppl=12, pte_size=4_000_000_000 + ), # 4GB + "llama3_2-1b_instruct": TestExampleLLMScript.LlmSpecs( + SM8650=37, SM8750=45, ppl=16, pte_size=1_500_000_000 + ), # 1.5 GB + "llama3_2-3b_instruct": TestExampleLLMScript.LlmSpecs( + SM8650=21, SM8750=26, ppl=11, pte_size=2_800_000_000 + ), # 2.8 GB + "qwen2_5-0_5b": TestExampleLLMScript.LlmSpecs( + SM8650=115, SM8750=155, ppl=15, pte_size=600_000_000 + ), # 600 MB + "qwen2_5-1_5b": TestExampleLLMScript.LlmSpecs( + SM8650=38, SM8750=47, ppl=10, pte_size=1_500_000_000 + ), # 1.5 GB + "qwen3-0_6b": TestExampleLLMScript.LlmSpecs( + SM8650=47, SM8750=68, ppl=21, pte_size=700_000_000 + ), # 700 MB + "qwen3-1_7b": TestExampleLLMScript.LlmSpecs( + SM8650=28, SM8750=34, ppl=15, pte_size=1_800_000_000 + ), # 1.8 GB + "smollm2_135m": TestExampleLLMScript.LlmSpecs( + SM8650=214, SM8750=260, ppl=23, pte_size=210_000_000 + ), # 210 MB + "smollm3-3b": TestExampleLLMScript.LlmSpecs( + SM8650=23, SM8750=28, ppl=10, pte_size=2_600_000_000 + ), # 2.6 GB + } - def test_static_gemma_2b(self): - if not self.required_envs(): + def test_static_llm_model(self): + if not self.required_envs([self.model_name]): self.skipTest("missing required envs") + assert ( + self.model_name in self.llm_specs + ), f"Unable to find {self.model_name} under model_specs." - prompt = "My favourite condiment is " + is_llama_model = self.model_name in { + "llama3_2-1b_instruct", + "llama3_2-3b_instruct", + } + if is_llama_model: + assert ( + self.llama_artifacts is not None + ), "Please provide path to llama artifacts" + + prompt = ( + "I would like to learn python, could you teach me with a simple example?" + ) cmds = [ "python", f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", @@ -5843,18 +5928,33 @@ def test_static_gemma_2b(self): str(self.port), "--prompt", f"{prompt}", + "--temperature", + "0", "--decoder_model", - "gemma-2b", + self.model_name, "--model_mode", "kv", "--max_seq_len", "1024", - "--eval_perplexity", + "--run_lm_eval", "--tasks", "wikitext", "--limit", "1", ] + + if is_llama_model: + cmds.extend( + [ + "--checkpoint", + f"{self.llama_artifacts}/consolidated.00.pth", + "--params", + f"{self.llama_artifacts}/params.json", + "--tokenizer_model", + f"{self.llama_artifacts}/tokenizer.model", + ] + ) + if self.compile_only: cmds.extend(["--compile_only"]) elif self.device: @@ -5874,19 +5974,30 @@ def test_static_gemma_2b(self): if "Error" in msg: self.fail(msg["Error"]) else: - inference_speed_ref = {"SM8650": 32, "SM8750": 36} - self.assertLessEqual(msg["wiki_ppl"], 35) - self.assertLessEqual(msg["pte_size"], 2_700_000_000) # 2.7GB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) + llm_spec = self.llm_specs[self.model_name] + pte_size = msg["pte_size"] + self.assertLessEqual(pte_size, llm_spec.pte_size) + print(f"Model Name: {self.model_name}\nTarget Device: {self.model}") + print(f"PTE Size: {pte_size} bytes") + if not self.compile_only: + ppl = msg["wiki_ppl"] + print(f"PPL: {ppl}") + self.assertLessEqual(ppl, llm_spec.ppl) + if not self.enable_x86_64 and hasattr(llm_spec, self.model): + device_inference_speed = msg["inference_speed"] + expected_inference_speed = getattr(llm_spec, self.model) + print( + f"Prompt Evaluation: {device_inference_speed} tokens/second" + ) + self.assertGreaterEqual( + device_inference_speed, expected_inference_speed + ) - def test_static_gemma3_1b(self): + def test_codegen2_1b(self): if not self.required_envs(): self.skipTest("missing required envs") - prompt = "My favourite condiment is " + prompt = "def hello_world():" cmds = [ "python", f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", @@ -5901,23 +6012,15 @@ def test_static_gemma3_1b(self): "--port", str(self.port), "--prompt", - f"{prompt}", - "--ptq", - "16a4w_block", + prompt, "--temperature", "0", "--decoder_model", - "gemma3-1b", + "codegen2_1b", "--model_mode", "kv", "--max_seq_len", - "1024", - "--eval_perplexity", - "--tasks", - "wikitext", - "--limit", - "1", - "--enable_masked_softmax", + "128", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -5930,6 +6033,7 @@ def test_static_gemma3_1b(self): if self.pre_gen_pte: cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + golden_start_with = "def hello_world():" p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() @@ -5939,26 +6043,20 @@ def test_static_gemma3_1b(self): self.fail(msg["Error"]) else: if not self.compile_only: - self.assertLessEqual(msg["wiki_ppl"], 23) + model_out = msg["result"][0] + self.assertTrue( + model_out.startswith(golden_start_with), + f"Expected Output: {golden_start_with}. Actual Output: {model_out}", + ) if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 1_200_000_000) # 1.2GB - inference_speed_ref = {"SM8650": 70, "SM8750": 100} - if ( - not self.compile_only - and not self.enable_x86_64 - and self.model in inference_speed_ref - ): - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) + self.assertLessEqual(pte_size, 1_200_000_000) # 1200MB + if not self.compile_only and not self.enable_x86_64: + self.assertGreaterEqual(msg["inference_speed"], 60) - def test_llama3_2_instruct(self): + def test_granite_3_3_2b_instruct(self): if not self.required_envs(): self.skipTest("missing required envs") - assert ( - self.llama_artifacts is not None - ), "Please provide path to llama artifacts" prompt = "What is the meaning of life?" cmds = [ @@ -5970,14 +6068,6 @@ def test_llama3_2_instruct(self): self.build_folder, "--model", self.model, - "--target", - self.target, - "--checkpoint", - f"{self.llama_artifacts}/consolidated.00.pth", - "--params", - f"{self.llama_artifacts}/params.json", - "--tokenizer_model", - f"{self.llama_artifacts}/tokenizer.model", "--ip", self.ip, "--port", @@ -5987,16 +6077,18 @@ def test_llama3_2_instruct(self): "--temperature", "0", "--decoder_model", - "llama3_2-1b_instruct", + "granite_3_3-2b_instruct", "--model_mode", "kv", "--max_seq_len", "1024", - "--eval_perplexity", + "--run_lm_eval", "--tasks", - "wikitext", + "hellaswag", "--limit", - "1", + "10", + "--kv_updater", + "shift_pointer", ] if self.compile_only: cmds.extend(["--compile_only"]) @@ -6017,14 +6109,14 @@ def test_llama3_2_instruct(self): if "Error" in msg: self.fail(msg["Error"]) else: - inference_speed_ref = {"SM8650": 37, "SM8750": 49} + inference_speed_ref = {"SM8650": 20, "SM8750": 22} if ( not self.compile_only and not self.enable_x86_64 and self.model in inference_speed_ref ): - self.assertLessEqual(msg["pte_size"], 1_500_000_000) - self.assertLessEqual(msg["wiki_ppl"], 15) + self.assertLessEqual(msg["pte_size"], 1_600_000_000) + self.assertGreaterEqual(msg["acc_norm"], 0.2) self.assertGreaterEqual( msg["inference_speed"], inference_speed_ref[self.model] ) @@ -6182,184 +6274,8 @@ def test_llama_stories_110m(self): if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai - def test_static_phi4(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "phi_4_mini", - "--model_mode", - "kv", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--tasks", - "wikitext", - "--limit", - "1", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - inference_speed_ref = {"SM8650": 14, "SM8750": 19} - self.assertLessEqual(msg["wiki_ppl"], 12) - self.assertLessEqual(msg["pte_size"], 4_000_000_000) # 4GB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) - - def test_static_qwen2_5(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "qwen2_5-0_5b", - "--model_mode", - "kv", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--tasks", - "wikitext", - "--limit", - "1", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - inference_speed_ref = {"SM8650": 115, "SM8750": 155} - self.assertLessEqual(msg["wiki_ppl"], 15) - self.assertLessEqual(msg["pte_size"], 600_000_000) # 600MB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) - - def test_static_qwen3(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "qwen3-0_6b", - "--model_mode", - "kv", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--tasks", - "wikitext", - "--limit", - "1", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - inference_speed_ref = {"SM8650": 38, "SM8750": 56} - self.assertLessEqual(msg["wiki_ppl"], 18) - self.assertLessEqual(msg["pte_size"], 950_000_000) # 950MB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) - def test_qwen2_5(self): + # This is not testing static llm flow. if not self.required_envs([]): self.skipTest("missing required envs") prompt = "My favourite condiment is " @@ -6413,125 +6329,6 @@ def test_qwen2_5(self): f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'", ) - def test_static_smollm2(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "smollm2_135m", - "--model_mode", - "kv", - "--temperature", - "0", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--task", - "wikitext", - "--limit", - "1", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - print("Perplexity score: ", msg["wiki_ppl"]) - self.assertLessEqual(msg["wiki_ppl"], 25) - if not self.enable_x86_64: - self.assertGreaterEqual(msg["inference_speed"], 200) - - def test_static_smollm3(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "My favourite condiment is " - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - "--prompt", - f"{prompt}", - "--decoder_model", - "smollm3-3b", - "--model_mode", - "kv", - "--temperature", - "0", - "--max_seq_len", - "1024", - "--eval_perplexity", - "--task", - "wikitext", - "--limit", - "1", - ] - if self.compile_only: - cmds.extend(["--compile_only"]) - elif self.device: - cmds.extend(["--device", self.device]) - if self.host: - cmds.extend(["--host", self.host]) - elif self.enable_x86_64: - cmds.extend(["--enable_x86_64"]) - if self.pre_gen_pte: - cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - inference_speed_ref = {"SM8650": 23, "SM8750": 28} - self.assertLessEqual(msg["wiki_ppl"], 10) - self.assertLessEqual(msg["pte_size"], 2_600_000_000) # 2.6GB - if self.model in inference_speed_ref: - self.assertGreaterEqual( - msg["inference_speed"], inference_speed_ref[self.model] - ) - class TestExampleOssScript(TestQNN): def test_albert(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 0f0c237a9e1..c0cc8daab03 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -178,6 +178,8 @@ class TestQNN(unittest.TestCase): dump_intermediate_outputs: bool = False inference_speed: float = 0.0 inference_speed_output_path = "outputs/inference_speed.txt" + model_name: str = "" + oss_repo: str = "" def _assert_outputs_equal(self, model_output, ref_output): self.assertTrue(len(ref_output) == len(model_output)) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index d26e9530f0b..20a1d3c0f72 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -58,6 +58,7 @@ EdgeProgramManager, to_edge_transform_and_lower, ) +from tabulate import tabulate from torch._decomp import core_aten_decompositions, remove_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes @@ -197,7 +198,7 @@ def dump_context_from_pte(pte_path) -> List[str]: with open(pte_path, "rb") as f: program_data = f.read() - program = deserialize_pte_binary(program_data) + program = deserialize_pte_binary(program_data).program ctx_path = os.path.dirname(pte_path) dummy_compiler_specs = generate_qnn_executorch_compiler_spec( @@ -1089,9 +1090,11 @@ def generate_qnn_executorch_compiler_spec( def get_soc_to_arch_map(): return { "SA8295": HtpArch.V68, + "SM8350": HtpArch.V68, "SM8450": HtpArch.V69, "SM8475": HtpArch.V69, "SM8550": HtpArch.V73, + "SA8255": HtpArch.V73, "SM8650": HtpArch.V75, "SM8750": HtpArch.V79, "SSG2115P": HtpArch.V73, @@ -1107,9 +1110,11 @@ def get_soc_to_arch_map(): def get_soc_to_chipset_map(): return { "SA8295": QcomChipset.SA8295, + "SM8350": QcomChipset.SM8350, "SM8450": QcomChipset.SM8450, "SM8475": QcomChipset.SM8475, "SM8550": QcomChipset.SM8550, + "SA8255": QcomChipset.SA8255, "SM8650": QcomChipset.SM8650, "SM8750": QcomChipset.SM8750, "SSG2115P": QcomChipset.SSG2115P, @@ -1122,6 +1127,35 @@ def get_soc_to_chipset_map(): } +def show_nn_module_stack_for_quant_recipe(gm: torch.fx.GraphModule, supported_ops): + """ + Print a quick preview of op targets and module stack. + + Use this to inspect the FX graph and identify module stack, which helps you craft regex or op-target for quantization recipe. + + """ + + module_metadata = {} + for node in gm.graph.nodes: + target = node.target + deepest_module = None + if node.op == "call_function" and "nn_module_stack" in node.meta: + deepest_module = list(node.meta["nn_module_stack"].values())[-1][0] + if node.target in supported_ops: + module_metadata.setdefault((target, deepest_module), []).append(node) + + table_rows = [] + for (target, module_stack), nodes in module_metadata.items(): + node_names = ", ".join([node.name for node in nodes]) + table_rows.append([str(target), module_stack, node_names]) + + print( + tabulate( + table_rows, headers=["Op Target", "Module Stack", "Nodes"], tablefmt="grid" + ) + ) + + def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable): """ Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index 29394951bd7..f3c9ee75083 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -132,6 +132,8 @@ def all_flows() -> dict[str, TestFlow]: ARM_ETHOS_U85_FLOW, ARM_TOSA_FP_FLOW, ARM_TOSA_INT_FLOW, + ARM_VGF_FP_FLOW, + ARM_VGF_INT_FLOW, ) flows += [ @@ -139,6 +141,8 @@ def all_flows() -> dict[str, TestFlow]: ARM_TOSA_INT_FLOW, ARM_ETHOS_U55_FLOW, ARM_ETHOS_U85_FLOW, + ARM_VGF_FP_FLOW, + ARM_VGF_INT_FLOW, ] except Exception as e: logger.info(f"Skipping ARM flow registration: {e}") diff --git a/backends/test/suite/flows/arm.py b/backends/test/suite/flows/arm.py index a690e4681f8..29ef504d50c 100644 --- a/backends/test/suite/flows/arm.py +++ b/backends/test/suite/flows/arm.py @@ -5,19 +5,22 @@ # Create flows for Arm Backends used to test operator and model suits +from collections.abc import Callable + from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.quantizer import get_symmetric_quantization_config from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.util._factory import create_quantizer from executorch.backends.test.suite.flow import TestFlow from executorch.backends.xnnpack.test.tester.tester import Quantize def _create_arm_flow( - name, - compile_spec: ArmCompileSpec, + name: str, + compile_spec_factory: Callable[[], ArmCompileSpec], + support_serialize: bool = True, + quantize: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, use_portable_ops: bool = True, @@ -25,24 +28,23 @@ def _create_arm_flow( ) -> TestFlow: def _create_arm_tester(*args, **kwargs) -> ArmTester: - kwargs["compile_spec"] = compile_spec + spec = compile_spec_factory() + kwargs["compile_spec"] = spec return ArmTester( *args, **kwargs, use_portable_ops=use_portable_ops, timeout=timeout ) - support_serialize = not isinstance(compile_spec, TosaCompileSpec) - quantize = compile_spec.tosa_spec.support_integer() - - if quantize is True: + if quantize: def create_quantize_stage() -> Quantize: - quantizer = create_quantizer(compile_spec) + spec = compile_spec_factory() + quantizer = create_quantizer(spec) quantization_config = get_symmetric_quantization_config( is_per_channel=per_channel_quantization ) if symmetric_io_quantization: quantizer.set_io(quantization_config) - return Quantize(quantizer, quantization_config) + return Quantize(quantizer, quantization_config) # type: ignore return TestFlow( name, @@ -50,23 +52,41 @@ def create_quantize_stage() -> Quantize: tester_factory=_create_arm_tester, supports_serialize=support_serialize, quantize=quantize, - quantize_stage_factory=(create_quantize_stage if quantize is True else False), + quantize_stage_factory=(create_quantize_stage if quantize else False), # type: ignore ) ARM_TOSA_FP_FLOW = _create_arm_flow( "arm_tosa_fp", - common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), + support_serialize=False, + quantize=False, ) ARM_TOSA_INT_FLOW = _create_arm_flow( "arm_tosa_int", - common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + lambda: common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"), + support_serialize=False, + quantize=True, ) ARM_ETHOS_U55_FLOW = _create_arm_flow( "arm_ethos_u55", - common.get_u55_compile_spec(), + lambda: common.get_u55_compile_spec(), + quantize=True, ) ARM_ETHOS_U85_FLOW = _create_arm_flow( "arm_ethos_u85", - common.get_u85_compile_spec(), + lambda: common.get_u85_compile_spec(), + quantize=True, +) +ARM_VGF_FP_FLOW = _create_arm_flow( + "arm_vgf_fp", + lambda: common.get_vgf_compile_spec(tosa_spec="TOSA-1.0+FP"), + quantize=False, + use_portable_ops=False, +) +ARM_VGF_INT_FLOW = _create_arm_flow( + "arm_vgf_int", + lambda: common.get_vgf_compile_spec(tosa_spec="TOSA-1.0+INT"), + quantize=True, + use_portable_ops=False, ) diff --git a/backends/transforms/remove_clone_ops.py b/backends/transforms/remove_clone_ops.py index 01fe2ee26a4..79b93af8beb 100644 --- a/backends/transforms/remove_clone_ops.py +++ b/backends/transforms/remove_clone_ops.py @@ -25,8 +25,9 @@ class RemoveCloneOpsTransform(ExportPass): exir_ops.edge.dim_order_ops._clone_dim_order.default, } - def __init__(self) -> None: + def __init__(self, preserve_input_output_copies: bool = False) -> None: super().__init__() + self._preserve_input_output_copies = preserve_input_output_copies def _remove(self, graph_module: torch.fx.GraphModule) -> None: dequant_nodes = [] @@ -38,6 +39,11 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None: if self._is_non_identity_clone(n): continue + # If preserve_input_output_copies is set, don't remove clones that directly + # copy from input to output. + if self._is_input_output_copy(n) and self._preserve_input_output_copies: + continue + to_be_removed = n for user_n in list(n.users.keys()): user_n.replace_input_with(n, n.args[0]) @@ -76,3 +82,16 @@ def _is_non_identity_clone(self, node: torch.fx.Node) -> bool: ) return False + + def _is_input_output_copy(self, node: torch.fx.Node) -> bool: + """Return True if the node input is a graph input and output goes into an output node.""" + + input_node = node.args[0] + if input_node.op != "placeholder": + return False + + for users in node.users: + if users.op == "output": + return True + + return False diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index ae1a0b79654..453b4814637 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -63,19 +63,6 @@ runtime.python_library( ], ) -runtime.python_library( - name = "remove_local_scalar_dense", - srcs = ["remove_local_scalar_dense_ops.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - "//caffe2:torch", - "//executorch/exir:pass_base", - "//executorch/exir/dialects:lib", - ], -) - runtime.python_library( name = "remove_redundant_ops", srcs = ["remove_redundant_ops.py"], @@ -117,19 +104,6 @@ runtime.python_library( ], ) -runtime.python_library( - name = "replace_qdq", - srcs = ["replace_qdq.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - "//caffe2:torch", - "//executorch/backends/vulkan:utils_lib", - "//executorch/exir:pass_base", - ], -) - runtime.python_library( name = "fuse_patterns", srcs = ["fuse_patterns.py"], @@ -161,9 +135,7 @@ runtime.python_library( ":fuse_quantized_ops", ":insert_prepack_nodes", ":remove_asserts", - ":remove_local_scalar_dense", ":remove_redundant_ops", - ":replace_qdq", ":squeeze_unsqueeze_inputs", ":tag_memory_meta_pass", ] diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 169bd60543c..d6a6823ca88 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -16,13 +16,9 @@ remove_asserts, RemoveAssertsTransform, ) -from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import ( - RemoveLocalScalarDenseOpsTransform, -) from executorch.backends.vulkan._passes.remove_redundant_ops import ( RemoveRedundantOpsTransform, ) -from executorch.backends.vulkan._passes.replace_qdq import ReplaceQDQPass from executorch.backends.vulkan._passes.squeeze_unsqueeze_inputs import ( SqueezeUnsqueezeInputs, ) @@ -35,9 +31,7 @@ "insert_prepack_nodes", "remove_asserts", "RemoveAssertsTransform", - "RemoveLocalScalarDenseOpsTransform", "RemoveRedundantOpsTransform", - "ReplaceQDQPass", "SqueezeUnsqueezeInputs", "TagMemoryMetaPass", ] diff --git a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py deleted file mode 100644 index 6ce3572ec0c..00000000000 --- a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - -from torch._subclasses.fake_tensor import FakeTensor - - -def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool: - """ - Converting a tensor to a scalar via tensor[0].item() creates a index_select + - local_scalar_dense pattern in the graph. Check if a node is the start of this pattern. - """ - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.select_copy.int - and len(node.users) == 1 - ): - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default - - return False - - -def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None: - """ - A scalar tensor in the Vulkan backend is a tensor that can be represented as a scalar - value instead of a Tensor object. The criteria for identifying a tensor as a scalar - tensor are as follows: - - 1. The tensor has only 1 element - 2. One of the node's uses is converting it to a scalar via `tensor[0].item()`, which - creates a index_select + local_scalar_dense pattern in the graph - - If any of these criteria are fulfilled, then tag the node for the tensor to mark it - so that it is added as a scalar value during serialization. - """ - tensor_val = node.meta["val"] - if not isinstance(tensor_val, FakeTensor): - return - - # Scalar tensors must have only one element - if tensor_val.numel() != 1: - return - - for user in node.users: - if node_is_local_scalar_dense_chain(user): - node.meta["etvk_is_scalar_tensor"] = True - - -def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None: - """ - Remove the index_select + local_scalar_dense pattern in the graph in favor of passing - the original scalar tensor directly. - """ - replace_node = node.args[0] - assert isinstance(replace_node, torch.fx.Node) - # If the argument to the local_scalar_dense op is a select op with only - # one user, and the argument to the select op is a tensor with only one - # element (i.e. a scalar tensor), then replace the entire pattern with the - # scalar tensor. - if ( - replace_node.op == "call_function" - and replace_node.target == exir_ops.edge.aten.select_copy.int - ): - # pyre-ignore - if replace_node.args[0].meta["val"].numel() == 1: - replace_node = replace_node.args[0] - assert isinstance(replace_node, torch.fx.Node) - assert replace_node.meta.get("etvk_is_scalar_tensor", True) - - with graph.inserting_after(node): - node.replace_all_uses_with(replace_node) - - -def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph: - """ - The purpose of this pass is twofold: - 1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria) - 2. Remove the index_select + local_scalar_dense pattern in the graph in favor of - passing the original scalar tensor directly (see `remove_local_scalar_dense_chain()`) - - This makes it easier to deal with scalar tensors in the Vulkan backend. In particular, - it allows serializing scalar tensors as SymInt objects instead of Tensor objects. - Because scalar tensors are often used to inform tensor shapes, their values need to - be easily accessed by the CPU during resizing logic, while also being able to reflect - updates to their value in any GPU shaders that reference them. - """ - target_op = torch.ops.aten._local_scalar_dense.default - for node in graph.nodes: - tag_node_if_scalar_tensor(node) - - if node.op == "call_function" and node.target == target_op: - remove_local_scalar_dense_chain(graph, node) - - graph.eliminate_dead_code() - return graph - - -class RemoveLocalScalarDenseOpsTransform(ExportPass): - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph) - return PassResult(graph_module, True) diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py index 8e602dd17b4..25bdd34de70 100644 --- a/backends/vulkan/_passes/remove_redundant_ops.py +++ b/backends/vulkan/_passes/remove_redundant_ops.py @@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass): exir_ops.edge.aten.lift_fresh_copy.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten.expand_copy.default, } def __init__(self) -> None: super(RemoveRedundantOpsTransform, self).__init__() def _should_remove(self, node: torch.fx.Node) -> bool: - if node.target in self.redundant_ops: - return True - - # Only remove to_copy if dtype does not change. Otherwise, memory format changes - # will be handled internally by the backend. - if ( - node.target == exir_ops.edge.aten._to_copy.default - or node.target == torch.ops.aten._to_copy.default - ): - src_dtype = node.meta["val"].dtype - # pyre-ignore - dst_dtype = node.args[0].meta["val"].dtype - return src_dtype == dst_dtype - - return False + if node.target not in self.redundant_ops: + return False + + orig_node = node.args[0] + assert isinstance(orig_node, torch.fx.Node) + + src_dtype = orig_node.meta["val"].dtype + dst_dtype = node.meta["val"].dtype + + # Do not remove if the op is converting the dtype. + if src_dtype != dst_dtype: + return False + + src_shape = orig_node.meta["val"].shape + dst_shape = node.meta["val"].shape + + return src_shape == dst_shape def _remove(self, graph_module: torch.fx.GraphModule) -> None: for node in graph_module.graph.nodes: if not self._should_remove(node): continue - with graph_module.graph.inserting_after(node): - node.replace_all_uses_with(node.args[0]) + node.replace_all_uses_with(node.args[0]) graph_module.graph.eliminate_dead_code() diff --git a/backends/vulkan/_passes/replace_qdq.py b/backends/vulkan/_passes/replace_qdq.py deleted file mode 100644 index fcfcdfc4c18..00000000000 --- a/backends/vulkan/_passes/replace_qdq.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import executorch.backends.vulkan.utils as utils -import torch -from executorch.exir.dialects._ops import ops as exir_ops - -from executorch.exir.pass_base import ExportPass, PassResult - - -class ReplaceQDQPass(ExportPass): - """ - Replace standard quantize/dequantize ops with custom conv-specific ops when they - feed into/from quantized convolution operations. This optimization allows the - backend to handle quantization more efficiently for convolution operations. - """ - - def __init__(self): - super(ReplaceQDQPass, self).__init__() - - def call(self, graph_module: torch.fx.GraphModule): - # Track nodes that need to be replaced - nodes_to_replace = [] - - for node in graph_module.graph.nodes: - # Check if this is the custom quantized conv2d op - if node.target in [ - exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default, - exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default, - exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default, - ]: - # Replace quantize op feeding into conv2d (first argument is the quantized input) - quantized_input_node = node.args[0] - if isinstance( - quantized_input_node, torch.fx.Node - ) and utils.is_quant_node(quantized_input_node): - # Get the arguments from the original quantize node - input_tensor = quantized_input_node.args[0] - scale = quantized_input_node.args[1] - zero_point = quantized_input_node.args[2] - - nodes_to_replace.append( - { - "old_node": quantized_input_node, - "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, - "args": (input_tensor, scale, zero_point), - "node_type": "quantize_input", - } - ) - - # Find dequantize ops that consume the output of this conv2d - for user in node.users: - if utils.is_dequant_node(user): - # Get the arguments from the original dequantize node - scale = user.args[1] - zero_point = user.args[2] - - nodes_to_replace.append( - { - "old_node": user, - "new_target": exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default, - "args": ( - node, - scale, - zero_point, - ), # node is the conv2d output - "node_type": "dequantize_output", - } - ) - - # Apply the replacements - for replacement in nodes_to_replace: - old_node = replacement["old_node"] - new_target = replacement["new_target"] - new_args = replacement["args"] - - with graph_module.graph.inserting_before(old_node): - new_node = graph_module.graph.create_node( - "call_function", new_target, args=new_args - ) - new_node.meta = old_node.meta.copy() - old_node.replace_all_uses_with(new_node) - - # Clean up the graph - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - - # Re-trace to validate everything is ok - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 8ed71aa1dae..00b6c62d5d2 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -6,22 +6,16 @@ import logging import operator - from typing import Any import executorch.backends.vulkan.utils as utils - import torch - from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures - from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, ) - from executorch.exir.dialects._ops import ops as exir_ops - from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.tensor import TensorSpec @@ -130,15 +124,17 @@ def __init__( texture_limits: utils.ImageExtents, default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D, default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED, + force_fp16: bool = False, ): super().__init__() self.default_storage: VkStorageType = default_storage_type self.default_layout: VkMemoryLayout = default_memory_layout self.texture_limits = texture_limits + self.force_fp16 = force_fp16 # Magic number to limit "lookahead" when tracing through users of an operator # to constrain the representation of its arguments/outputs. - self.max_trace_search_depth = 20 + self.max_trace_search_depth = None def is_valid_op_node(self, node: Any) -> bool: """ @@ -230,9 +226,10 @@ def get_arg_tensor_source_repset( """ arg_node = op_node.args[arg_i] - # For non-tensor arguments, return ANY_STORAGE + # For non-tensor arguments, return ALL_STORAGES_REPSET so that the respset does + # not appear to be empty. if not utils.is_tensor_arg_node(arg_node): - return utils.ANY_STORAGE + return utils.ALL_STORAGES_REPSET # Special case for cat - use the first tensor in the list as representative if isinstance(arg_node, list): @@ -361,12 +358,18 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No 2. Then, try to trace through the users of the argument to find a representation that can be used for as long as possible without needing a transition. """ + # If forcing fp16, then try to use texture storage whenever possible. This is + # a temporary stopgap measure until all buffer implementations properly account + # for potential overflow of fp16 representation range when doing math in fp16. + if self.force_fp16: + op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE) + arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i) op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) arg_repset = op_repsets.get_arg_repset(arg_i) if arg_repset.is_constrained(): - return arg_repset + return arg_node = op_repsets.op_node.args[arg_i] @@ -376,6 +379,20 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) + def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None: + """ + Similar to the `constrain_op_arg_repset` function, but for the output repset of + the operator. + """ + out_repset = op_repsets.get_out_repset(0) + if out_repset.is_constrained(): + return + + op_node = op_repsets.op_node + out_respset = self.trace_node_users_to_constrain_repset(op_node, out_repset) + + op_repsets.try_constrain_with_out_repset(out_respset) + def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: # For most ops, constraining the argument repsets will also contrain the output # repset due to OpRepSets maintaining synchronization rules. @@ -383,14 +400,12 @@ def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: if utils.is_tensor_arg_node(op_repsets.op_node.args[i]): self.constrain_op_arg_repset(i, op_repsets) - # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there - # is no need to constrain output repsets explicitly. Currently, the exceptions - # (i.e. choose qparams) already define constrined repsets for the output, so - # there is again no need to explicitly constrain the outputs. If an operator - # appears later on that does not sync input and output representations, and - # defines ambiguous repsets for the output tensor(s), then we will need to add - # additional logic to this function to constrain the output repsets separately - # from the input repsets. + # However, some operators do not sync input and output representations and also + # define ambiguous repsets for the output tensor(s). In those cases we will need + # to execute additional logic to constrain the output repsets separately from + # the input repsets. + if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr: + self.constrain_op_out_repset(op_repsets) def set_op_node_tensor_reprs( self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 6e5aa926d37..aed8b591fea 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -9,6 +9,8 @@ import executorch.backends.vulkan.patterns as vk_patterns import torch.library +from torch._subclasses.fake_tensor import FakeTensor + namespace = "et_vk" lib = torch.library.Library(namespace, "DEF") @@ -537,42 +539,6 @@ def apply_rotary_emb_impl( lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd") apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name) -############################# -## quantize/dequantize ops ## -############################# - - -def quantize_q8ta_for_conv2d_impl( - input: torch.Tensor, - scale: float, - zero_point: int, -): - return torch.ops.quantized_decomposed.quantize_per_tensor( - input, scale, zero_point, -128, 127, torch.int8 - ) - - -name = "quantize_q8ta_for_conv2d" -lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor") -lib.impl(name, quantize_q8ta_for_conv2d_impl, "CompositeExplicitAutograd") -quantize_q8ta_for_conv2d_op = getattr(getattr(torch.ops, namespace), name) - - -def dequantize_q8to_from_conv2d_impl( - input: torch.Tensor, - scale: float, - zero_point: int, -): - return torch.ops.quantized_decomposed.dequantize_per_tensor( - input, scale, zero_point, -128, 127, input.dtype - ) - - -name = "dequantize_q8to_from_conv2d" -lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor") -lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd") -dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name) - ######################## ## add_q8ta_q8ta_q8to ## ######################## @@ -614,3 +580,18 @@ def add_q8ta_q8ta_q8to_impl( ) lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd") add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name) + +############################# +## select_as_symint ## +############################# + + +def select_as_symint_impl(x: torch.Tensor, dim: int, index: int): + assert isinstance(x, FakeTensor) + return x.fake_mode.shape_env.create_unbacked_symint() + + +name = "select_as_symint" +lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt") +lib.impl(name, select_as_symint_impl, "Meta") +select_as_symint_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index da127f72528..feba4f6f072 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -7,17 +7,12 @@ # pyre-unsafe import operator - from typing import Any, Callable, Dict, List, Optional, Union import executorch.backends.vulkan.custom_ops_lib # noqa - import executorch.backends.vulkan.utils as utils - import torch - from executorch.exir.dialects._ops import ops as exir_ops - from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._subclasses.fake_tensor import FakeTensor @@ -129,6 +124,7 @@ def update_features_impl(op: OpKey): # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, + operator.sub, operator.lt, operator.gt, operator.ge, @@ -148,13 +144,9 @@ def register_ephemeral_op(): @update_features( [ - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_token.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.dequantize_per_token.default, ] ) @@ -297,27 +289,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default) def register_to_copy_dim_order_op(): - # Currently there is no "real" implementation for to_dim_order_copy, but it can be - # removed as long as the operator is not changing the dtype, i.e. the operator call - # is modifying the dim order only. Therefore, check that the input and output dtypes - # are the same, if so the operator is safe to remove. - def check_dim_order_copy_node(node: torch.fx.Node) -> bool: - in_arg = node.args[0] - if not isinstance(in_arg, torch.fx.Node): - return False - - in_tensor = in_arg.meta.get("val", None) - out_tensor = node.meta.get("val", None) - - if in_tensor.dtype != out_tensor.dtype: - return False - - return True - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_BUFFER, supports_resize=True, - are_node_inputs_supported_fn=check_dim_order_copy_node, ) @@ -652,35 +626,35 @@ def register_quantized_binary_op(): @update_features( [ - exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, ] ) -def register_quantize_for_conv2d_op(): +def register_quantize_op(): return OpFeatures( inputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], outputs_storage=[ utils.PACKED_INT8_4W4C_BUFFER, ], - supports_resize=False, ) @update_features( [ - exir_ops.edge.et_vk.dequantize_q8to_from_conv2d.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, ] ) -def register_dequantize_for_conv2d_op(): +def register_dequantize_op(): return OpFeatures( inputs_storage=[ utils.PACKED_INT8_4W4C_BUFFER, ], outputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], - supports_resize=False, ) @@ -709,7 +683,7 @@ def register_sdpa_ops(): @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, ) @@ -733,6 +707,7 @@ def register_view_ops(): exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.gather.default, ] ) def register_view_ops_with_buffer_meta(): @@ -765,6 +740,7 @@ def register_cat_op(): [ exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.split_with_sizes_copy.default, ] ) def register_transfer_ops(): @@ -807,10 +783,7 @@ def register_ported_op(): # Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions @update_features( [ - # Tensor combination exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.split_with_sizes_copy.default, - exir_ops.edge.aten.split.Tensor, ] ) def register_ported_op_all_packed_dims(): diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 059b3a07be0..bc3bf14bf14 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -184,36 +184,6 @@ def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: - """ - Scalar tensors are usually converted to scalar values in the graph via` - scalar_tensor[0].item()` in Python, which translates to a chain of - `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. - This function marks the entire chain as supported by the Vulkan delegate. - - Later, within vulkan_preprocess there will be a graph transform which replaces - the chain with passing in the scalar tensor directly. - - Similar to the `is_linear_permute` function, this function has 2 return values. - """ - if node.target == exir_ops.edge.aten.select_copy.int: - if len(node.users) != 1: - return False, False - # pyre-ignore - if node.args[0].meta["val"].numel() != 1: - return False, False - - local_scalar_dense = list(node.users.keys())[0] - if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: - return False, False - - return self.is_in_local_scalar_dense_chain(local_scalar_dense) - - if node.target == torch.ops.aten._local_scalar_dense.default: - return True, all(self.node_is_compatible(user)[0] for user in node.users) - - return False, False - def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": logger.info( @@ -261,16 +231,6 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 self.log_skip(node, "permute node of non compatible linear node") return False - ( - is_in_local_scalar_dense_chain, - dst_node_is_compatible, - ) = self.is_in_local_scalar_dense_chain(node) - if is_in_local_scalar_dense_chain and dst_node_is_compatible: - return True - elif is_in_local_scalar_dense_chain: - self.log_skip(node, "local scalar dense of incompatible op node") - return False - features = None if target not in vulkan_supported_ops: # For some ops, i.e. custom ops the name is registered instead of the diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS index 285efe2b933..3baf7c9e251 100644 --- a/backends/vulkan/patterns/TARGETS +++ b/backends/vulkan/patterns/TARGETS @@ -12,6 +12,8 @@ runtime.python_library( "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", + "sdpa.py", + "select_as_symint.py", ], visibility = [ "//executorch/backends/...", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index e23dfc7629c..9b875def944 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -14,6 +14,10 @@ import executorch.backends.vulkan.patterns.rope # noqa +import executorch.backends.vulkan.patterns.sdpa # noqa + +import executorch.backends.vulkan.patterns.select_as_symint # noqa + import torch from executorch.backends.vulkan.patterns.pattern_registry import ( diff --git a/backends/vulkan/patterns/sdpa.py b/backends/vulkan/patterns/sdpa.py new file mode 100644 index 00000000000..f67799f9b76 --- /dev/null +++ b/backends/vulkan/patterns/sdpa.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram + + +def is_update_cache_node(node: Any) -> bool: + return utils.node_has_target(node, "llama::update_cache") + + +def is_custom_sdpa_node(node: Any) -> bool: + return utils.node_has_target(node, "llama::custom_sdpa") + + +def is_sdpa_with_kv_cache_node(node: Any) -> bool: + return utils.node_has_target(node, "llama::sdpa_with_kv_cache") + + +class CausalSDPAMatch(PatternMatch): + def __init__(self, custom_sdpa_node: torch.fx.Node) -> None: + self.anchor_node = custom_sdpa_node + self.match_found = False + self.all_nodes = [self.anchor_node] + + # llama.custom_sdpa has signature: + # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask, dropout_p, is_causal, scale) -> output + if len(custom_sdpa_node.args) < 4: + return + + self.query_node = custom_sdpa_node.args[0] + self.key_cache_node = custom_sdpa_node.args[1] + self.value_cache_node = custom_sdpa_node.args[2] + self.start_pos_node = custom_sdpa_node.args[3] + self.attn_mask_node = custom_sdpa_node.args[4] + self.dropout_p_node = custom_sdpa_node.args[5] + self.is_causal_node = custom_sdpa_node.args[6] + if len(custom_sdpa_node.args) > 7: + self.scale_node = custom_sdpa_node.args[7] + else: + self.scale_node = None + + # try to find update key cache node + self.update_key_cache_node = None + for user in self.key_cache_node.users: + if is_update_cache_node(user): + self.update_key_cache_node = user + break + + self.key_projection_node = None + if self.update_key_cache_node is not None: + self.key_projection_node = self.update_key_cache_node.args[0] + + # find update value cache node + self.update_value_cache_node = None + for user in self.value_cache_node.users: + if is_update_cache_node(user): + self.update_value_cache_node = user + break + + self.value_projection_node = None + if self.update_value_cache_node is not None: + self.value_projection_node = self.update_value_cache_node.args[0] + + # We have additional optional arguments but we don't need to capture them + # since the new op doesn't use them + + self.match_found = True + + +@register_pattern_detector("causal_sdpa") +def find_causal_sdpa_patterns( + node: torch.fx.Node, +) -> Optional[CausalSDPAMatch]: + if not is_custom_sdpa_node(node): + return None + + matched_pattern = CausalSDPAMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +def find_singleton_start_pos_node(graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + if is_update_cache_node(node): + return node.args[2] + + if is_sdpa_with_kv_cache_node(node): + return node.args[5] + + raise Exception( + "Could not find an instance of llama::update_cache or sdpa_with_kv_cache" + ) + + +@register_pattern_replacement("causal_sdpa") +def replace_custom_sdpa_with_causal_sdpa( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: CausalSDPAMatch, +): + assert match.update_key_cache_node is not None + assert match.key_projection_node is not None + assert match.update_value_cache_node is not None + assert match.value_projection_node is not None + + singleton_start_pos_node = find_singleton_start_pos_node(graph_module) + + with graph_module.graph.inserting_before(match.anchor_node): + new_node = graph_module.graph.create_node( + "call_function", + torch.ops.llama.sdpa_with_kv_cache.default, + args=( + match.query_node, + match.key_projection_node, + match.value_projection_node, + match.key_cache_node, + match.value_cache_node, + singleton_start_pos_node, + 1, + match.attn_mask_node, + match.dropout_p_node, + match.is_causal_node, + match.scale_node, + ), + ) + + new_node.meta["val"] = match.anchor_node.meta["val"] + match.anchor_node.replace_all_uses_with(new_node) + + # Manually erase update_cache nodes since DCE will not remove them since they + # modify inputs (specifically, the cache args are modified) + graph_module.graph.erase_node(match.update_key_cache_node) + graph_module.graph.erase_node(match.update_value_cache_node) diff --git a/backends/vulkan/patterns/select_as_symint.py b/backends/vulkan/patterns/select_as_symint.py new file mode 100644 index 00000000000..e7226b08188 --- /dev/null +++ b/backends/vulkan/patterns/select_as_symint.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class SelectAsSymIntMatch(PatternMatch): + def __init__(self, local_scalar_dense_node: torch.fx.Node) -> None: + self.anchor_node = local_scalar_dense_node + self.match_found = False + + # Check if the input to local_scalar_dense is a select_copy node + if len(local_scalar_dense_node.args) < 1: + return + + select_node = local_scalar_dense_node.args[0] + if not isinstance(select_node, torch.fx.Node): + return + + if ( + select_node.op != "call_function" + or select_node.target != exir_ops.edge.aten.select_copy.int + ): + return + + # select_copy.int has signature: select_copy(Tensor self, int dim, int index) + if len(select_node.args) < 3: + return + + self.select_node = select_node + + self.tensor_node = select_node.args[0] + self.dim_node = select_node.args[1] + self.index_node = select_node.args[2] + + self.all_nodes = [ + self.anchor_node, + self.select_node, + self.tensor_node, + self.dim_node, + self.index_node, + ] + + self.match_found = True + + +@register_pattern_detector("select_as_symint") +def find_select_as_symint_patterns( + node: torch.fx.Node, +) -> Optional[SelectAsSymIntMatch]: + if node.target != torch.ops.aten._local_scalar_dense.default: + return None + + matched_pattern = SelectAsSymIntMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("select_as_symint") +def replace_select_local_scalar_dense_with_select_as_symint( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: SelectAsSymIntMatch, +): + with graph_module.graph.inserting_before(match.anchor_node): + new_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.select_as_symint.default, + args=( + match.tensor_node, + match.dim_node, + match.index_node, + ), + ) + + new_node.meta["val"] = match.anchor_node.meta["val"] + match.anchor_node.replace_all_uses_with(new_node) + + # # Remove both the local_scalar_dense and select_copy nodes + # graph_module.graph.erase_node(match.anchor_node) + # # Only erase select_node if it has no other users + # if len(match.select_node.users) == 0: + # graph_module.graph.erase_node(match.select_node) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index fe8cc83c481..fecef2598c7 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -179,6 +179,12 @@ GraphConfig get_graph_config(ArrayRef& compile_specs) { config.expect_dynamic_shapes = true; } } + if (strcmp(spec.key, "warmup_execute_after_compile") == 0) { + ET_CHECK_MSG(value_size == sizeof(uint8_t), "Unexpected value size!"); + bool value = getBool(value_data); + + config.warmup_execute_after_compile = value; + } } #ifdef ET_EVENT_TRACER_ENABLED config.enable_querypool = true; @@ -579,6 +585,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->prepack(); + compute_graph->optional_warmup_execute(); + return Error::Ok; } @@ -649,7 +657,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } } - if (should_propagate_resize) { + if (should_propagate_resize || compute_graph->has_data_dependent_shapes()) { compute_graph->propagate_resize(); } diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h index 6d0e5a4a457..09788e66b0f 100644 --- a/backends/vulkan/runtime/api/containers/StagingBuffer.h +++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h @@ -112,6 +112,20 @@ class StagingBuffer final { inline void set_staging_zeros() { memset(data(), 0, nbytes()); } + + template + T select_element_at_dim( + const std::vector& sizes, + const int64_t dim, + const int64_t index) { + int64_t stride = 1; + for (size_t i = dim + 1; i < sizes.size(); ++i) { + stride *= sizes[i]; + } + const int64_t offset = index * stride; + const T* typed_data = reinterpret_cast(data()); + return typed_data[offset]; + } }; } // namespace api diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 2ec63a89df8..f96dbd6848f 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -683,6 +683,17 @@ int32_t ComputeGraph::read_symint(const ValueRef idx) { return get_symint(idx)->get(); } +ValueRef ComputeGraph::staging_of(const ValueRef idx) { + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs_[i].value == idx) { + if (is_valid(inputs_[i].staging)) { + return inputs_[i].staging; + } + } + } + VK_THROW("Could not find staging buffer for value at index ", idx); +} + SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { if (idx >= shared_objects_.size()) { shared_objects_.resize(static_cast(idx + 1)); @@ -1096,6 +1107,12 @@ void ComputeGraph::prepack() { } } +void ComputeGraph::optional_warmup_execute() { + if (config_.warmup_execute_after_compile) { + execute(); + } +} + void ComputeGraph::execute() { if (deferred_cmd_list_.empty()) { context_->flush(); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index dbd5536279c..7415a9dd2df 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -639,6 +639,10 @@ class ComputeGraph final { bool device_name_contains(const char* substr); + int64_t max_buffer_numel() { + return static_cast(context_->adapter_ptr()->max_buffer_numel()); + } + // // Graph Building // @@ -824,6 +828,8 @@ class ComputeGraph final { inputs_.push_back({idx, kDummyValueRef}); } + ValueRef staging_of(const ValueRef idx); + inline void set_val_as_output(const ValueRef idx) { outputs_.push_back({idx, kDummyValueRef}); } @@ -1027,6 +1033,12 @@ class ComputeGraph final { */ void prepack(); + // + // Optional Graph Execution + // + + void optional_warmup_execute(); + // // Graph Execution // @@ -1081,6 +1093,14 @@ class ComputeGraph final { return can_use_int8_dot_product_; } + inline void set_has_data_dependent_shapes() { + config_.has_data_dependent_shapes = true; + } + + inline bool has_data_dependent_shapes() const { + return config_.has_data_dependent_shapes; + } + /* * Check whether the GPU supports 8 bit buffers. */ diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp index 20b8f6f7c00..9a919a42573 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.cpp +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -64,6 +64,7 @@ GraphConfig::GraphConfig() { enable_local_wg_size_override = false; local_wg_size_override = {}; + has_data_dependent_shapes = false; expect_dynamic_shapes = false; force_resize = false; diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index 7533df3b685..20d01362ef1 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -33,8 +33,11 @@ struct GraphConfig final { bool enable_local_wg_size_override; utils::uvec3 local_wg_size_override; + // If true, then resize functions should always be called even if input shapes + // have not changed. + bool has_data_dependent_shapes = false; // Whether or not the ComputeGraph should expect input shapes to be dynamic - bool expect_dynamic_shapes; + bool expect_dynamic_shapes = false; // Used for testing/debugging only. Forces ExecuteNode to trigger the resize // function even if none of the inputs have been updated. bool force_resize = false; @@ -68,6 +71,10 @@ struct GraphConfig final { // many command buffers. size_t execute_max_cmds = 0; + // If true, then the graph will be executed once immediately after it is + // compiled. + bool warmup_execute_after_compile = false; + vkapi::Adapter* external_adapter; // Generate a default graph config with pre-configured settings diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index aa46ee76336..40cc67517ea 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -14,15 +14,19 @@ ExecuteNode::ExecuteNode( const ResizeFunction& resize_fn, const std::vector& resize_args, const std::vector& args, - const std::string& name) + const std::string& name, + const bool has_data_dependent_shape) : resize_fn_(resize_fn), resize_args_(resize_args), args_(args), - name_(name) {} + name_(name), + has_data_dependent_shape_(has_data_dependent_shape) {} bool ExecuteNode::trigger_resize(ComputeGraph* graph) { bool any_arg_updated = was_any_arg_updated(graph); - if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) { + if (resize_fn_ && + (any_arg_updated || graph->graphconfig().force_resize || + has_data_dependent_shape_)) { resize_fn_(graph, args_, resize_args_); any_arg_updated = true; } diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 323036cef90..4dbad882dea 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -57,7 +57,8 @@ class ExecuteNode { const ResizeFunction& resize_fn = nullptr, const std::vector& resize_args = {}, const std::vector& args = {}, - const std::string& name = "Graph Node"); + const std::string& name = "Graph Node", + const bool has_data_dependent_shape = false); virtual ~ExecuteNode() = default; @@ -87,6 +88,7 @@ class ExecuteNode { const std::vector resize_args_; const std::vector args_; const std::string name_; + bool has_data_dependent_shape_ = false; }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl index 8b69642d2e9..d0bd1809d11 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl @@ -25,8 +25,6 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -#extension GL_EXT_debug_printf : enable -#define DEBUG_MODE #include "indexing.glslh" #include "common.glslh" diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl index 971f66f93e5..4f51e9ff679 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl @@ -22,7 +22,6 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -#define DEBUG_MODE #include "indexing.glslh" ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl deleted file mode 100644 index 7e21bcf0eba..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} -#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(SCALE_OUT_DTYPE)} -${define_required_extensions(ZP_OUT_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - float eps; - }; -$if MODE == "per_token": - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - layout(push_constant) uniform BlockPC { - ivec4 blockSize; // WHCN (>=1) - ivec4 numBlocks; // #blocks along W,H,C,N - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP - int quant_min; - int quant_max; - float eps; - }; - -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} -${layout_declare_ubo(B, "ivec4", "t_scale_strides")} -${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} -${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} - -#include "indexing_utils.h" -#include "choose_qparams.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -#define NWORKERS 64 - -// Shared memory for reduction - must match local work group size -shared float shared_min[NWORKERS]; -shared float shared_max[NWORKERS]; - -/* - Quantization Parameter Computation Shader (Buffer Storage) - This shader computes quantization parameters (scale and zero_point) for converting - floating-point tensors to n-bit integer representations while preserving the - original data range as much as possible. The computed parameters enable efficient - quantization by mapping the continuous floating-point range to discrete integer values. - - Important Considerations: - (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - - Workgroup Configuration: - - choose_qparams_per_tensor - This mode computes a single set of quantization parameters for the entire tensor. - Uses parallel reduction across all threads to find global min/max values. - - (*) global_wg_size: {1, 1, 1} (single workgroup processes entire tensor) - (*) local_wg_size: {64, 1, 1} (matches NWORKERS for shared memory) - - - choose_qparams_per_token - This mode computes separate quantization parameters for each token in the tensor. - Each workgroup processes one token independently to find token-specific min/max. - - (*) global_wg_size: {num_tokens, 1, 1} (one workgroup per token) - (*) local_wg_size: {1, 1, 1} (single thread per token) - - - choose_qparams_block_wise - This mode computes quantization parameters for each block of elements, allowing - fine-grained control over quantization granularity within the tensor. Each block - is processed independently to find its own min/max values and compute corresponding - scale and zero_point parameters. - - (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) - (*) local_wg_size: {1, 1, 1} (single thread per block) - - Block-wise quantization supports multiple mapping types for scale/zero_point calculation: - - - mapping_type = 0 (ASYMMETRIC): - Uses asymmetric quantization where the full floating-point range [min, max] is - mapped to the quantized range [quant_min, quant_max]. This preserves the original - data distribution but may not center zero optimally. - - Calculation: - scale = (max - min) / (quant_max - quant_min) - zero_point = quant_min - round(min / scale) - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - scale = (10.2 - (-3.5)) / (7 - (-8)) = 13.7 / 15 = 0.913 - zero_point = -8 - round(-3.5 / 0.913) = -8 - (-4) = -4 - - - mapping_type = 1 (SYMMETRIC): - Uses symmetric quantization where the range is centered around zero. The scale - is computed based on the maximum absolute value, ensuring zero is exactly - representable in the quantized domain. - - Calculation: - max_abs = max(abs(min), abs(max)) - scale = max_abs / ((quant_max - quant_min) / 2) - zero_point = (quant_max + quant_min + 1) / 2 // midpoint - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - max_abs = max(3.5, 10.2) = 10.2 - scale = 10.2 / ((7 - (-8)) / 2) = 10.2 / 7.5 = 1.36 - zero_point = (-8 + 7 + 1) / 2 = 0 - - - mapping_type = 2 (SYMMETRIC_NO_CLIPPING_ERR): - A variant of symmetric quantization that minimizes clipping errors by computing - separate scales for positive and negative ranges, then using the maximum. This - reduces quantization error on the dominant range while ensuring no values are - clipped. - - Calculation: - smin = abs(min) / abs(quant_min) // scale for negative range - smax = max / quant_max // scale for positive range - scale = max(smin, smax) // use larger scale to avoid clipping - zero_point = (quant_max + quant_min + 1) / 2 // midpoint - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - smin = 3.5 / 8 = 0.4375 - smax = 10.2 / 7 = 1.457 - scale = max(0.4375, 1.457) = 1.457 // use smax to avoid clipping positives - zero_point = (-8 + 7 + 1) / 2 = 0 - - Tree Reduction Algorithm for Min/Max Finding: - The shader uses a parallel tree reduction algorithm to efficiently find minimum and - maximum values across multiple threads. This approach reduces the number of memory - accesses and synchronization points compared to sequential scanning. - - Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: - - Step 1 - Initial Population: - Each thread loads its assigned value into shared memory arrays. - shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - - Step 2 - Stride 1 (Compare Adjacent Pairs): - Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. - shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - Active: | 0 | | 2 | | 4 | | 6 | | - - Step 3 - Stride 2 (Compare Pairs of Pairs): - Threads 0,4 compare with threads 2,6 respectively. - shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) - shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - Active: | 0 | | | | 4 | | | | - - Step 4 - Stride 4 (Final Comparison): - Thread 0 compares with thread 4 to get final result. - shared_min: | 0 | | | | | | | | (min(1,0) = 0) - shared_max: | 10 | | | | | | | | (max(10,5) = 10) - Active: | 0 | | | | | | | | - - Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - - The tree reduction completes in log_2(N) steps where N is the number of threads, - providing O(log N) time complexity instead of O(N) for sequential reduction. - - Quantization Parameter Calculation: - Once min/max values are determined, the shader computes: - - scale = (max - min) / (quant_max - quant_min) - - zero_point = quantization offset to map floating-point zero to integer range - - Mode-Specific Behavior: - - Per-Tensor: Single workgroup with strided access across entire tensor - - Per-Token: Multiple workgroups, each processing one token independently - - Block-Wise: Each thread processes assigned blocks using nested loops over block dimensions -*/ - -#ifdef per_tensor - -void choose_qparams_per_tensor() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); - - // Each thread processes multiple elements with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - for (uint i = global_id; i < total_elements; i += total_threads) { - float val = t_in[i]; - if (!isnan(val) && !isinf(val)) { - if (!found_valid) { - thread_min = val; - thread_max = val; - found_valid = true; - } else { - thread_min = min(thread_min, val); - thread_max = max(thread_max, val); - } - } - } - - // Intra-group reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final result calculation (single workgroup only) - if (local_id == 0) { - float global_min = shared_min[0]; - float global_max = shared_max[0]; - - float scale_val; - int zero_point_val; - // Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant - calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - - t_scale[0] = SCALE_OUT_T(scale_val); - t_zero_point[0] = ZP_OUT_T(zero_point_val); - } -} - -#elif defined(per_token) - -void choose_qparams_per_token() { - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); - uint token_size = total_elements / uint(num_tokens); - - const uint TOTAL_TOKENS = uint(num_tokens); - - /* each invocation handles token-ids: id, id+STRIDE, id+2·STRIDE … */ - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - for (uint token_id = gl_GlobalInvocationID.x; token_id < TOTAL_TOKENS; token_id += STRIDE) { - // Calculate the start and end indices for this token - uint token_start = token_id * token_size; - uint token_end = token_start + token_size; - - // Each thread processes the entire token - float lo = 1.0/0.0; // +INF - float hi = -1.0/0.0; // -INF - bool found_valid = false; - - // Process all elements in this token - for (uint i = token_start; i < token_end; i++) { - float val = t_in[i]; - if (!isnan(val) && !isinf(val)) { - if (!found_valid) { - lo = hi = val; - found_valid = true; - } else { - lo = min(lo, val); - hi = max(hi, val); - } - } - } - - if (!found_valid) { - // If no valid values were found, use default values - lo = 0.0; - hi = 0.0; - } - - // Calculate scale and zero point directly - float scale_val; - int zero_point_val; - // Use default values: mapping_type=0 (ASYMMETRIC), eps=1e-5 - calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - - // Write results - t_scale[token_id] = SCALE_OUT_T(scale_val); - t_zero_point[token_id] = ZP_OUT_T(zero_point_val); - } -} - -#elif defined(block_wise) - -ivec4 block_id_to_coord(uint bid) { - ivec4 bc; - bc.w = int(bid) / blockStride.w; - - int r = int(bid) - bc.w * blockStride.w; - bc.z = r / blockStride.z; - - r -= bc.z * blockStride.z; - bc.y = r / blockStride.y; - - r -= bc.y * blockStride.y; - bc.x = r; - return bc; -} - -void choose_qparams_block_wise() { - const uint TOTAL_BLOCKS = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); - - // each invocation handles block-ids: id, id+STRIDE, id+2·STRIDE - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE) { - // block -> WHCN coordinate - ivec4 bc = block_id_to_coord(block_id); - ivec4 blockStart = bc * blockSize; // first element (inclusive) - ivec4 blockEnd = blockStart + blockSize; // last element (exclusive) - - // min / max scan over the block - float lo = 1.0/0.0; // +INF - float hi = -1.0/0.0; // -INF - bool found_valid = false; - - // Calculate actual block dimensions - ivec4 actualBlockSize = blockEnd - blockStart; - int blockElements = actualBlockSize.x * actualBlockSize.y * actualBlockSize.z * actualBlockSize.w; - - // Linear iteration over block elements - for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { - // Convert linear index to 4D coordinates within block - int remaining = elemIdx; - int dn = remaining / (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); - remaining -= dn * (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); - int dc = remaining / (actualBlockSize.x * actualBlockSize.y); - remaining -= dc * (actualBlockSize.x * actualBlockSize.y); - int dh = remaining / actualBlockSize.x; - int dw = remaining - dh * actualBlockSize.x; - - ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); - uint idx = tidx_to_bufi(tidx, t_in_strides); - float v = t_in[idx]; - - if (!isnan(v) && !isinf(v)) { - if (!found_valid) { - lo = hi = v; - found_valid = true; - } else { - lo = min(lo, v); - hi = max(hi, v); - } - } - } - - // Handle the case where no valid values were found in the block - if (!found_valid) { - lo = 0.0; - hi = 0.0; - } - - float scale_val; - int zero_point_val; - calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale_val, zero_point_val); - - t_scale[block_id] = SCALE_OUT_T(scale_val); - t_zero_point[block_id] = ZP_OUT_T(zero_point_val); - } -} - -#endif - -void main() { - choose_qparams_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml deleted file mode 100644 index 8459b043baa..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml +++ /dev/null @@ -1,22 +0,0 @@ -choose_qparams_buffer: - parameter_names_with_default_values: - IN_DTYPE: float - SCALE_OUT_DTYPE: float - ZP_OUT_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: float - SCALE_OUT_DTYPE: - - VALUE: float - ZP_OUT_DTYPE: - - VALUE: int32 - - VALUE: int8 - - VALUE: float - shader_variants: - - NAME: choose_qparams_tensor_buffer - MODE: per_tensor - - NAME: choose_qparams_per_token_asymmetric_buffer - MODE: per_token - - NAME: choose_qparams_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl deleted file mode 100644 index a17a3ae41dd..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ /dev/null @@ -1,533 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} -#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} -#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(SCALE_OUT_DTYPE)} -${define_required_extensions(ZP_OUT_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -$if MODE != "block_wise": - ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "texture3d")} - ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "texture3d")} -$else: - ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} - ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} - -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - float eps; - }; -$if MODE == "per_token": - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - layout(push_constant) uniform BlockPC { - ivec4 blockSize; // WHCN (>=1) - ivec4 numBlocks; // #blocks along W,H,C,N - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP - int quant_min; - int quant_max; - float eps; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -$if MODE != "block_wise": - ${layout_declare_ubo(B, "ivec3", "t_scale_limits")} - ${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} -$else: - ${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} - ${layout_declare_ubo(B, "ivec4", "t_scale_strides")} - ${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} - ${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} - - -#include "indexing_utils.h" -#include "choose_qparams.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -#define NWORKERS 64 - -// Shared memory for reduction - must match local work group size -shared float shared_min[NWORKERS]; -shared float shared_max[NWORKERS]; - -/*/* - Quantization Parameter Computation Shader (Buffer Storage) - This shader computes quantization parameters (scale and zero_point) for converting - floating-point tensors to n-bit integer representations while preserving the - original data range as much as possible. The computed parameters enable efficient - quantization by mapping the continuous floating-point range to discrete integer values. - - Important Considerations: - (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - - Workgroup Configuration: - - choose_qparams_per_tensor - This mode computes a single set of quantization parameters for the entire tensor. - Uses parallel reduction across all threads to find global min/max values. - - (*) global_wg_size: default - (*) local_wg_size: default - - - choose_qparams_per_token - This mode computes separate quantization parameters for each token in the tensor. - Each workgroup processes one token independently to find token-specific min/max. - - (*) global_wg_size: default - (*) local_wg_size: {1, 1, 1} - - - choose_qparams_block_wise - This mode computes quantization parameters for each block of elements, allowing - fine-grained control over quantization granularity within the tensor. Each block - is processed independently to find its own min/max values and compute corresponding - scale and zero_point parameters. - - NOTE: This mode currently only supports buffer storage for the output. - - (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) - (*) local_wg_size: {1, 1, 1} (single thread per block) - - Tree Reduction Algorithm for Min/Max Finding: - The shader uses a parallel tree reduction algorithm to efficiently find minimum and - maximum values across multiple threads. This approach reduces the number of memory - accesses and synchronization points compared to sequential scanning. - - Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: - - Step 1 - Initial Population: - Each thread loads its assigned value into shared memory arrays. - shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - - Step 2 - Stride 1 (Compare Adjacent Pairs): - Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. - shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - Active: | 0 | | 2 | | 4 | | 6 | | - - Step 3 - Stride 2 (Compare Pairs of Pairs): - Threads 0,4 compare with threads 2,6 respectively. - shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) - shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - Active: | 0 | | | | 4 | | | | - - Step 4 - Stride 4 (Final Comparison): - Thread 0 compares with thread 4 to get final result. - shared_min: | 0 | | | | | | | | (min(1,0) = 0) - shared_max: | 10 | | | | | | | | (max(10,5) = 10) - Active: | 0 | | | | | | | | - - Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - - The tree reduction completes in log_2(N) steps where N is the number of threads, - providing O(log N) time complexity instead of O(N) for sequential reduction. - - Quantization Parameter Calculation: - Once min/max values are determined, the shader computes: - - scale = (max - min) / (quant_max - quant_min) - - zero_point = quantization offset to map floating-point zero to integer range - - Mode-Specific Behavior: - - Per-Tensor: Single workgroup with strided access across entire tensor - - Per-Token: Multiple workgroups, each processing one token independently -*/ - -#ifdef per_tensor - -void choose_qparams_per_tensor() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - - uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); - - // Each thread processes multiple texels with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - // Process texels with stride across all threads - for (uint texel_idx = global_id; texel_idx < total_texels; texel_idx += total_threads) { - // Convert linear texel index to 3D coordinates - uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); - uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); - uint y = remainder / uint(t_in_limits.x); - uint x = remainder % uint(t_in_limits.x); - ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - - FVEC4_T texel_data = load_texel(t_in, texel_pos); - - // For texture storage, we assume width-packed (packed_dim = 0) - // Calculate number of valid elements in this texel (handle padding) - int packed_dim = 0; // Width dimension is packed - ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format - ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); - - // Calculate total tensor elements to determine padding - int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; - int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + - tensor_coord.z * sizes.x * sizes.y; - int remaining_elements = total_elements - (linear_tensor_idx); - int valid_elements = min(4, remaining_elements); - - // Find min/max within this texel, considering only valid elements - if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { - if (!found_valid) { - thread_min = texel_data.x; - thread_max = texel_data.x; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.x); - thread_max = max(thread_max, texel_data.x); - } - } - - if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { - if (!found_valid) { - thread_min = texel_data.y; - thread_max = texel_data.y; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.y); - thread_max = max(thread_max, texel_data.y); - } - } - - if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { - if (!found_valid) { - thread_min = texel_data.z; - thread_max = texel_data.z; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.z); - thread_max = max(thread_max, texel_data.z); - } - } - - if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { - if (!found_valid) { - thread_min = texel_data.w; - thread_max = texel_data.w; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.w); - thread_max = max(thread_max, texel_data.w); - } - } - } - - // Intra-workgroup reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final result calculation (single workgroup only for reliability) - if (local_id == 0 && group_id == 0) { - float global_min = shared_min[0]; - float global_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - - write_texel(t_scale, ivec3(0, 0, 0), vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); - write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); - } -} - -#elif defined(per_token) - -void choose_qparams_per_token() { - // Each token is processed by multiple workgroups for parallel reduction - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_workgroups = gl_NumWorkGroups.x; - - uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); - - // Calculate texels per token (assuming last dimension contains the token data) - // For per-token quantization, we assume tokens are along the last dimension - uint texels_per_token = total_texels / uint(num_tokens); - - // Calculate how many tokens each workgroup should process - uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; - - // Calculate which tokens this workgroup is responsible for - uint start_token = group_id * tokens_per_workgroup; - uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); - - // Process each token assigned to this workgroup - for (uint token_id = start_token; token_id < end_token; token_id++) { - // Calculate the texel range for this token - uint token_start_texel = token_id * texels_per_token; - uint token_end_texel = token_start_texel + texels_per_token; - - // Each thread processes multiple texels within the token - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - // Process texels within this token only - for (uint texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += gl_WorkGroupSize.x) { - // Convert linear texel index to 3D coordinates - uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); - uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); - uint y = remainder / uint(t_in_limits.x); - uint x = remainder % uint(t_in_limits.x); - ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - - FVEC4_T texel_data = load_texel(t_in, texel_pos); - - // For texture storage, we assume width-packed (packed_dim = 0) - // Calculate number of valid elements in this texel (handle padding) - int packed_dim = 0; // Width dimension is packed - ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format - ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); - - // Calculate total tensor elements to determine padding - int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; - int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + - tensor_coord.z * sizes.x * sizes.y; - int remaining_elements = total_elements - (linear_tensor_idx); - int valid_elements = min(4, remaining_elements); - - // Find min/max within this texel, considering only valid elements - if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { - if (!found_valid) { - thread_min = texel_data.x; - thread_max = texel_data.x; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.x); - thread_max = max(thread_max, texel_data.x); - } - } - - if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { - if (!found_valid) { - thread_min = texel_data.y; - thread_max = texel_data.y; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.y); - thread_max = max(thread_max, texel_data.y); - } - } - - if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { - if (!found_valid) { - thread_min = texel_data.z; - thread_max = texel_data.z; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.z); - thread_max = max(thread_max, texel_data.z); - } - } - - if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { - if (!found_valid) { - thread_min = texel_data.w; - thread_max = texel_data.w; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.w); - thread_max = max(thread_max, texel_data.w); - } - } - } - - // Intra-workgroup reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - // Handle infinity values properly - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final calculation for this token - if (local_id == 0) { - float token_min = shared_min[0]; - float token_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calc_scale_zp(token_min, token_max, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - - // Convert token_id to 3D coordinates for output texture - // Assuming output tensors have the same layout as input but with different dimensions - uint out_z = token_id / uint(t_scale_limits.x * t_scale_limits.y); - uint out_remainder = token_id % uint(t_scale_limits.x * t_scale_limits.y); - uint out_y = out_remainder / uint(t_scale_limits.x); - uint out_x = out_remainder % uint(t_scale_limits.x); - ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); - - write_texel(t_scale, out_pos, vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); - write_texel(t_zero_point, out_pos, ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); - } - - // Synchronize before processing next token - barrier(); - } -} - -#elif defined(block_wise) - -ivec4 block_id_to_coord(uint bid) { - ivec4 bc; - bc.w = int(bid) / blockStride.w; - - int r = int(bid) - bc.w * blockStride.w; - bc.z = r / blockStride.z; - - r -= bc.z * blockStride.z; - bc.y = r / blockStride.y; - - r -= bc.y * blockStride.y; - bc.x = r; - return bc; -} - -void choose_qparams_block_wise() { - const uint T = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - - // tensor full size in WHCN order - const ivec4 tensorSz = blockSize * numBlocks; - - // Process blocks with stride for better parallelization - for (uint blkIdx = gl_GlobalInvocationID.x; blkIdx < T; blkIdx += STRIDE) { - // block index in WHCN - const ivec4 b4d = block_id_to_coord(blkIdx); - const ivec4 blockStart = b4d * blockSize; - const ivec4 blockEnd = blockStart + blockSize; - - // scan all elements inside the block - float vmin = 3.402823e38; // +FLT_MAX - float vmax = -3.402823e38; // -FLT_MAX - bool found_valid = false; - - // Calculate total elements in block for linear iteration - const int blockElements = blockSize.x * blockSize.y * blockSize.z * blockSize.w; - - // Linear iteration over block elements (more cache-friendly) - for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { - // Convert linear index to 4D coordinates within block - int remaining = elemIdx; - int dn = remaining / (blockSize.x * blockSize.y * blockSize.z); - remaining -= dn * (blockSize.x * blockSize.y * blockSize.z); - int dc = remaining / (blockSize.x * blockSize.y); - remaining -= dc * (blockSize.x * blockSize.y); - int dh = remaining / blockSize.x; - int dw = remaining - dh * blockSize.x; - - ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); - - // skip padding when tensor size is not an exact multiple of block - if (any(greaterThanEqual(tidx, tensorSz))) { continue; } - - // tensor index -> (x,y,z,component) inside input texture - ivec4 posi = to_texture_elem_pos(tidx, tensorSz, 0); // 0 = W_DIM (width packed) - - // fetch texel and pick the element inside it - FVEC4_T texl = load_texel(t_in, posi.xyz); - float v; - if (posi.w == 0) v = texl.x; - else if (posi.w == 1) v = texl.y; - else if (posi.w == 2) v = texl.z; - else v = texl.w; - - if (!isnan(v) && !isinf(v)) { - if (!found_valid) { - vmin = vmax = v; - found_valid = true; - } else { - vmin = min(vmin, v); - vmax = max(vmax, v); - } - } - } - - // Handle case where no valid values were found - if (!found_valid) { - vmin = 0.0; - vmax = 0.0; - } - - // compute scale / zero‑point (same maths as buffer kernel) - float scale; - int zp; - calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp); - - // Write the scalar values directly to buffer using linear index - t_scale[blkIdx] = SCALE_OUT_T(scale); - t_zero_point[blkIdx] = ZP_OUT_T(zp); - } -} - -#endif - -void main() { - choose_qparams_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml deleted file mode 100644 index 12228822d4b..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml +++ /dev/null @@ -1,22 +0,0 @@ -choose_qparams_texture: - parameter_names_with_default_values: - IN_DTYPE: float - SCALE_OUT_DTYPE: float - ZP_OUT_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: float - SCALE_OUT_DTYPE: - - VALUE: float - ZP_OUT_DTYPE: - - VALUE: int32 - - VALUE: int8 - - VALUE: float - shader_variants: - - NAME: choose_qparams_tensor_texture3d - MODE: per_tensor - - NAME: choose_qparams_per_token_asymmetric_texture3d - MODE: per_token - - NAME: choose_qparams_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/clone.yaml b/backends/vulkan/runtime/graph/ops/glsl/clone.yaml index 1fdbf506bfd..a85d201046e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/clone.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/clone.yaml @@ -7,5 +7,7 @@ clone: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: clone diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh index 62c0922e3e3..9ade64910f2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -9,6 +9,10 @@ #ifndef COMMON_GLSLH #define COMMON_GLSLH +#ifdef DEBUG_MODE +#extension GL_EXT_debug_printf : enable +#endif + #define mul_2(x) ((x) << 1) #define mul_4(x) ((x) << 2) #define mul_8(x) ((x) << 3) @@ -82,4 +86,15 @@ int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) { return pack_into_int32(quantized); } +#ifdef DEBUG_MODE + +#define printf debugPrintfEXT + +void printVec4(vec4 texel) { + debugPrintfEXT( + "texel: %f, %f, %f, %f\\n", texel.x, texel.y, texel.z, texel.w); +} + +#endif // DEBUG_MODE + #endif // COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml index 39f96df5e90..36d0b879bdd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml @@ -6,6 +6,7 @@ concat_buffer: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 shader_variants: - NAME: concat_1_buffer NUM_INPUTS: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl index afab0c524d6..0611defa4c3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl @@ -113,8 +113,6 @@ void main() { VEC4_T out_texel = imageLoad(t_out, out_pos); - VEC4_T test_texel = VEC4_T(-1.0); - for (int comp = 0; comp < 4; ++comp) { ivec4 out_tidx = out_read_start_tidx; out_tidx[out_packed_dim] += comp; @@ -124,7 +122,6 @@ void main() { // of the previous input batch; if so, then don't overwrite this texel // element if (out_tidx[concat_dim] < concat_offset) { - test_texel[comp] = -5.0; continue; } @@ -164,7 +161,6 @@ void main() { inp${i}_packed_dim); out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w]; - test_texel[comp] = out_texel[comp]; continue; } else { diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml index ed5003382a1..d3de77d8ea9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml @@ -6,6 +6,7 @@ concat_texture: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 shader_variants: - NAME: concat_1_texture3d NUM_INPUTS: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh index be8a76421a5..a3934422e27 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh @@ -14,7 +14,21 @@ #include "linear_fp_input_tile.glslh" VEC4_T load_fp_input_texel(const Conv2dTensorIndex tidx) { +#ifdef INPUT_BUFFER + VEC4_T texel = VEC4_T(0); + const int c_idx = mul_4(tidx.data.z); + const int c_stride = input_sizes.y * input_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * input_sizes.x + tidx.data.x; + const int limit = min(input_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; i++) { + texel[i] = t_fp_input[base_buf_i + i * c_stride]; + } + return texel; +#else return texelFetch(t_fp_input, tidx.data, 0); +#endif } void load_fp_input_tile( @@ -23,7 +37,9 @@ void load_fp_input_tile( #if TILE_M == 4 && TILE_K4 == 1 Conv2dTensorIndex load_tidx = block_idx_to_tensor_idx(block_idx); [[unroll]] for (int w = 0; w < TILE_M; w++) { - tile.data[w][0] = load_fp_input_texel(load_tidx); + if (load_tidx.data.x < input_sizes.x) { + tile.data[w][0] = load_fp_input_texel(load_tidx); + } load_tidx.data.x++; } #else diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl deleted file mode 100644 index 39aa9b11a0d..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.glsl +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_type(DTYPE)} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "existing_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 in_sizes; - // Operates on (x, y, z) logical extents. - // channel_range is stored in range.w - ivec4 range; - // Analogus to range variable in copy. It defines the # of channel being - // copied. - // dst channel offset is stored in dst_offset.w - ivec4 dst_offset; - int src_channel_offset; -}; - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int packed_dim = unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); - -void main() { - // Note: Unlike other shaders, the range is often not equal to the destination - // texture extent. - const ivec3 lpos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(lpos, range.xyz))) { - return; - } - - const ivec3 out_lpos = lpos + dst_offset.xyz; - - const ivec4 out_tidx = lpos_to_tidx(out_lpos, out_sizes, out_axis_map.w, packed_dim); - - // First read the existing values to make sure the boundary values stay. - VEC4_T v = load_texel_lpos(existing_out, out_lpos, out_axis_map); - - ivec4 in_tidx = out_tidx; - for (int i=0; i<4; i++) { - - in_tidx[packed_dim] = out_tidx[packed_dim] - dst_offset.w + i; - - // Handle the partial update for begining of channel in an existing tensor. - // If the source channel index is below zero or exceeds the range, we skip - // updating the element to avoid overwriting existing data. - if ((in_tidx[packed_dim] < 0) || (in_tidx[packed_dim] >= range.w)) { - continue; - } - - // Readjust for the source offset. - in_tidx[packed_dim] += src_channel_offset; - - ivec4 in_posi = tidx_to_posi(in_tidx, in_sizes, in_axis_map, packed_dim); - v[i] = load_texel(t_in, in_posi.xyz)[in_posi.w]; - } - - write_texel_lpos(t_out, out_lpos, v, out_axis_map); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml deleted file mode 100644 index 984d9a09d43..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ /dev/null @@ -1,12 +0,0 @@ -copy_channel_offset: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: int32 - shader_variants: - - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl deleted file mode 100644 index 178814a90c3..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -${define_active_storage_type(STORAGE)} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform restrict Block { - ivec3 range; - // xyz is source offset w is channel size - ivec4 src_offset; - // xyz is destination offset w is channel size - ivec4 dst_offset; -}; - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); - -${layout_declare_spec_const(C, "int", "batch_index_function", "0")} - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, range))) { - return; - } - - ivec3 in_pos = pos + src_offset.xyz; - ivec3 out_pos = pos + dst_offset.xyz; - if (src_offset.w > 0) { - if (batch_index_function == 1) { - // batch index is calculated using source channel size - const int channel_index = pos.z % src_offset.w; - const int batch_index = pos.z / src_offset.w; - out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w; - } else if (batch_index_function == 2) { - // batch index is calculated using destination channel size - const int channel_index = pos.z % dst_offset.w; - const int batch_index = pos.z / dst_offset.w; - in_pos.z = channel_index + src_offset.z + batch_index * src_offset.w; - } - } - - write_texel_lpos( - t_out, - out_pos, - load_texel_lpos(t_in, in_pos, in_axis_map), - out_axis_map); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml deleted file mode 100644 index 09f5ca36ea4..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ /dev/null @@ -1,17 +0,0 @@ -copy_offset: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: int32 - - VALUE: int8 - - VALUE: uint8 - STORAGE: - - VALUE: texture3d - - VALUE: texture2d - shader_variants: - - NAME: copy_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl deleted file mode 100644 index 3100565d08a..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define VEC4_T ${texel_type(DTYPE)} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "existing_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} - -layout(push_constant) uniform restrict Block { - ivec4 range; - - // xyz is source offset w is channel size - ivec4 src_offset; - - // xyz is destination offset w is channel size - ivec4 dst_offset; -}; - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int packed_dim = unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, range.xyz))) { - return; - } - - // Position in input tensor - ivec3 in_pos = pos + src_offset.xyz; - in_pos[packed_dim] = pos[packed_dim] + (src_offset[packed_dim] >> 2); - - // Read input value mapping to this output texel - VEC4_T in_value = load_texel_lpos(t_in, in_pos, in_axis_map); - - // Starting offset to read from a texel - const int src_lane_offset = src_offset[packed_dim] & 0x3; - const bool has_src_lane_offset = src_lane_offset != 0; - - // If input lane offset is non zero i.e packed texel is composed from multiple sources - if (has_src_lane_offset) { - // Boundary values will come from next input texel in the packed dim. - ivec3 next_in_pos = in_pos; - next_in_pos[packed_dim] = in_pos[packed_dim] + 1; - VEC4_T next_value = load_texel_lpos(t_in, next_in_pos, in_axis_map); - - // Keep input values from the end of current input pixel based on src_lane_offset - // offset 1 means the first lane of current input texel is not a part of the output texel - // offset 2 means first 2 lanes are not and so on - // Copy next texel's values towards the end of input texel, based on lane offset - // offset 1 means the first lane from next texel is part of the input texel - // offset 2 means first 2 lanes from next texel is part of the input texel and so on - if (src_lane_offset == 1) { - in_value = ivec4(in_value.yzw, next_value.x); - } else if (src_lane_offset == 2) { - in_value = ivec4(in_value.zw, next_value.xy); - } else { - in_value = ivec4(in_value.w, next_value.xyz); - } - } - - // Starting offset to write at within a texel - const int out_lane_offset = dst_offset[packed_dim] & 0x3; - const bool has_dst_lane_offset = out_lane_offset != 0; - - ivec3 out_pos = pos + dst_offset.xyz; - out_pos[packed_dim] = pos[packed_dim] + (dst_offset[packed_dim] >> 2); - - VEC4_T out_value; - - // If lane offset is non zero i.e packed texel is composed from multiple sources - if (has_dst_lane_offset) { - // When position in packed dim is > 0 - if (pos[packed_dim] > 0) { - // Boundary values will come from previous input texel in the packed dim. - ivec3 prev_in_pos = in_pos; - prev_in_pos[packed_dim] = in_pos[packed_dim] - 1; - VEC4_T prev_value = load_texel_lpos(t_in, prev_in_pos, in_axis_map); - - // Shift values toward the beginning based on out_lane_offset - // offset 1 means the last lane from the previous texel is a part of the output texel - // offset 2 means last 2 lanes and so on - if (out_lane_offset == 1) { - out_value.x = prev_value.w; - } else if (out_lane_offset == 2) { - out_value.xy = prev_value.zw; - } else { - out_value.xyz = prev_value.yzw; - } - } else { - // When position in packed dim is == 0 - // Boundary values will be the previous texel values. - out_value = load_texel_lpos(existing_out, out_pos, out_axis_map); - } - - // Copy input values towards the end of output array, based on lane offset - // offset 1 means the first lane from previous texel is part of the output texel starting at offset - // offset 2 means first 2 lanes from the previous texel is part of the output texel and so on - if (out_lane_offset == 1) { - out_value.yzw = in_value.xyz; - } else if (out_lane_offset == 2) { - out_value.zw = in_value.xy; - } else { - out_value.w = in_value.x; - } - } else { - out_value = in_value; - } - - write_texel_lpos( - t_out, - out_pos, - out_value, - out_axis_map); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml deleted file mode 100644 index 6e55876cb28..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ /dev/null @@ -1,12 +0,0 @@ -copy_packed_dim_offset: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - STORAGE: texture3d - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: int32 - shader_variants: - - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh deleted file mode 100644 index 7194bebda35..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifndef DEQUANTIZE_GLSLH -#define DEQUANTIZE_GLSLH - -OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) { - return OUT_T(float(int(qvalue) - zero_point_val) * scale_val); -} - -#endif // DEQUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl deleted file mode 100644 index 57dc2d53fff..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "int", "out_numel")} -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_out_sizes")} -${layout_declare_ubo(B, "ivec4", "t_out_strides")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "dequantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); -const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); - -/* - Dequantization Shader (Buffer Storage) - This shader converts n-bit integer tensor values back to floating-point representations - using pre-computed quantization parameters (scale and zero_point). The dequantization - reconstructs the original floating-point values from their discrete integer representations - with minimal precision loss. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - dequantize_per_tensor - This mode reverses the uniform quantization applied across the entire tensor by using the - single scale and zero_point values to convert quantized integer values back to their original - floating-point representation. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_per_token - This mode reverses the quantization applied individually to each token (or element) in the - input by using separate scale and zero_point values for each token. For a tensor of shape - [B, S, H], it applies the inverse transformation token-wise across the B*S tokens, converting - quantized values back to their original floating-point representation for each group of H - elements independently. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_per_channel - This mode reverses the quantization applied separately to each channel of the input tensor - by using distinct scale and zero_point values for each channel. For a tensor of shape - [B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C - channels, converting quantized values back to their original floating-point representation - independently for each channel. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_block_wise - This mode reverses the block-wise quantization applied to groups of elements by using separate - scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the - inverse affine transformation per block to convert quantized values back to their original - floating-point representation. For example, if the tensor shape is [6, 9, 4] and - blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements, - and dequantization is performed independently on each block. - - (*) global_wg_size: default - (*) local_wg_size: default - - Dequantization Formula: - value = (qvalue - zero_point) * scale -*/ - -#ifdef per_tensor - -void dequantize_per_tensor() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); - - t_out[out_bufi] = value; -} - -#elif defined(per_token) - -void dequantize_per_token() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - int token_idx = 0; - - if (t_out_sizes.w > 1) { - // 4D tensor - token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.z > 1) { - // 3D tensor - token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.y > 1) { - // 2D tensor - token_idx = out_tidx.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - OUT_T value = dequantize_val(qvalue, float(t_scale[token_idx]), int(t_zero_point[token_idx])); - - t_out[out_bufi] = value; -} - -#elif defined(per_channel) - -void dequantize_per_channel() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - // Calculate channel index based on the dequantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (tidx.x) - // axis 1 -> H dimension (tidx.y) - // axis 2 -> C dimension (tidx.z) - // axis 3 -> N dimension (tidx.w) - int channel_idx = 0; - - if (axis == 0) { - channel_idx = out_tidx.x; - } else if (axis == 1) { - channel_idx = out_tidx.y; - } else if (axis == 2) { - channel_idx = out_tidx.z; - } else if (axis == 3) { - channel_idx = out_tidx.w; - } - - channel_idx = min(channel_idx, num_channels - 1); - - OUT_T value = dequantize_val(qvalue, float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); - - t_out[out_bufi] = value; -} - -#else // block_wise - -void dequantize_block_wise() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - const ivec4 bcoord = out_tidx / blockSize; - - const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - const OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); - - t_out[out_bufi] = value; -} - -#endif - -void main() { - dequantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml deleted file mode 100644 index a4375038a75..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ /dev/null @@ -1,31 +0,0 @@ -dequantize_buffer: - parameter_names_with_default_values: - IN_DTYPE: int32 - OUT_DTYPE: float - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - OUT_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: dequantize_per_tensor_buffer - MODE: per_tensor - - NAME: dequantize_per_token_buffer - MODE: per_token - - NAME: dequantize_per_channel_buffer - MODE: per_channel - - NAME: dequantize_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl deleted file mode 100644 index 19276cd8f7f..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} - -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_out_limits")} - -#include "indexing_utils.h" -#include "dequantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - * DEQUANTIZATION SHADER (TEXTURE STORAGE) - * - * This shader converts n-bit integer tensor values back to floating-point representations - * using pre-computed quantization parameters (scale and zero_point). The dequantization - * reconstructs the original floating-point values from their discrete integer representations - * with minimal precision loss. - * - * ALGORITHM: - * 1. Load quantized integer texel (4 values) from 3D texture - * 2. Apply dequantization formula to each component: value = (qvalue - zero_point) * scale - * 3. Store reconstructed floating-point texel to output texture - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with texel-based processing - * - Assumes width-packed layout (packed_dim = 0) for input/output textures - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - For per-token mode: scale/zero_point tensors must use buffer storage - * - Input/output textures: Must use standard axis mapping for per-token mode - * - * DEQUANTIZATION FORMULA VISUALIZATION: - * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: - * - * Integer Domain: Floating Point Domain: - * quant_min ──────────────► min_val - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * quant_max ──────────────► max_val - * - * Texel Dequantization Process: - * Input Texel: [-103, -128, -123, -96] (int4) - * Per-component dequantization with scale=0.1, zero_point=-128: - * Component 0: (-103 - (-128)) * 0.1 = 25 * 0.1 = 2.5 - * Component 1: (-128 - (-128)) * 0.1 = 0 * 0.1 = 0.0 - * Component 2: (-123 - (-128)) * 0.1 = 5 * 0.1 = 0.5 - * Component 3: (-96 - (-128)) * 0.1 = 32 * 0.1 = 3.2 - * Output Texel: [2.5, 0.0, 0.5, 3.2] (float4) - * - * PER-TENSOR DEQUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All texel components use same dequantization parameters - * - Parameters passed as push constants for efficiency - * - Each thread processes one texel (4 elements) independently - * - Formula: value[i] = (qvalue[i] - zero_point) * scale - * - * PER-TOKEN DEQUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates token_id from its 3D texture position - * - Scale/zero_point buffers accessed directly (not as textures) - * - Formula: value[i] = (qvalue[i] - zero_point[token_id]) * scale[token_id] - * - * Token ID calculation for texel at position (x, y, z): - * - 3D tensor: token_id = z * texture_height + y - * - 2D tensor: token_id = y - * - 1D tensor: token_id = 0 - */ - -#ifdef per_tensor - -void dequantize_per_tensor() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - // Skip if out of bounds - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); - - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - write_texel(t_out, pos, outtex); -} - -#elif defined(per_token) - -void dequantize_per_token() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - - int token_idx = 0; - ivec3 dims = t_in_limits; - - if (dims.z > 1) { - // 3D tensor - token_idx = pos.z * dims.y + pos.y; - } else if (dims.y > 1) { - // 2D tensor - token_idx = pos.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = float(t_scale[token_idx]); - int zero_point_val = int(t_zero_point[token_idx]); - - FVEC4_T outtex; - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - - write_texel(t_out, pos, outtex); -} - -#elif defined(per_channel) - -void dequantize_per_channel() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - // Calculate channel index based on the dequantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (pos.x) - // axis 1 -> H dimension (pos.y) - // axis 2 -> C dimension (pos.z) - // axis 3 -> N dimension (batch folding in texture storage) - - if (axis == 0) { - // Width dimension - each texel component has different channel index - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - int channel_idx = pos.x * 4 + i; - channel_idx = min(channel_idx, num_channels - 1); - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 1) { - int channel_idx = pos.y; - channel_idx = min(channel_idx, num_channels - 1); - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 2) { - // Channel dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - int channel_idx = folded_idx % num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 3) { - // Batch dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - // In this case num_channels actually corresponds to the number of channels - // the C dimension N(C)HW - int channel_idx = folded_idx / num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } - - write_texel(t_out, pos, outtex); -} - -#else // block_wise - -void dequantize_block_wise() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) - return; - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); - int foldedZ = pos.z; - - int C_total = numBlocks.z * blockSize.z; - - [[unroll]] for (int i = 0; i < 4; ++i) { - ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); - - ivec4 bcoord = tidx / blockSize; - int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - - write_texel(t_out, pos, outtex); -} - -#endif - -void main() { - dequantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml deleted file mode 100644 index 7a58e9410d3..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ /dev/null @@ -1,31 +0,0 @@ -dequantize_texture: - parameter_names_with_default_values: - IN_DTYPE: int32 - OUT_DTYPE: float - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - OUT_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: dequantize_per_tensor_texture3d - MODE: per_tensor - - NAME: dequantize_per_token_texture3d - MODE: per_token - - NAME: dequantize_per_channel_texture3d - MODE: per_channel - - NAME: dequantize_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl index 8b519a67eb6..c1a21e44c60 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl @@ -19,7 +19,6 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -#define DEBUG_MODE #include "indexing.glslh" ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl index ecfc10415a1..9a6295a8094 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl @@ -20,7 +20,6 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -#define DEBUG_MODE #include "common.glslh" #include "indexing.glslh" @@ -39,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) { indices_tidx.data.xyz = out_tidx.data.yzw; indices_tidx.data.w = 0; - TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple( - indices_tidx, indices); + TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple( + indices, indices_tidx); const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0); return in_texel[elem_pos.comp]; @@ -62,7 +61,7 @@ void main() { return; } - TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp); + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); const int embedding_idx = load_embedding_idx(out_tidx); const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x); diff --git a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml index 6d90e1fa8b1..887f7893061 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml @@ -6,5 +6,6 @@ expand_buffer: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: expand_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/full.yaml b/backends/vulkan/runtime/graph/ops/glsl/full.yaml index 1a5b0cb235e..5d7a983cae3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/full.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/full.yaml @@ -15,5 +15,6 @@ full: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: full diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl new file mode 100644 index 00000000000..318631a160f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl @@ -0,0 +1,57 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_index", "int", "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "BufferMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int gather_dim = 0; + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + + // Load the index value at the same position in the index tensor + const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx); + const int gather_idx = t_index[index_bufi]; + + // Construct the input tensor index by replacing the gather dimension + // with the gathered index value + TensorIndex input_tidx = out_tidx; + input_tidx.data[div_4(gather_dim)][mod_4(gather_dim)] = gather_idx; + + // Load from input tensor and store to output + const uint input_bufi = tensor_idx_to_linear_idx(inp, input_tidx); + + t_out[out_bufi] = t_input[input_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml new file mode 100644 index 00000000000..8e2cff21b61 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +gather_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: gather_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl new file mode 100644 index 00000000000..71e352a7875 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} +${layout_declare_ubo(B, "TextureMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int gather_dim = 0; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + ivec4 idx_texel = texelFetch(t_index, out_pos, 0); + + VEC4_T out_texel = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < 4; comp++) { + TensorIndex4D input_tidx = out_tidx; + int gather_idx = idx_texel[comp]; + input_tidx.data[gather_dim] = gather_idx; + + TextureElementIndex input_elem_pos = tensor4d_idx_to_texture_element_idx_simple( + inp, input_tidx); + + VEC4_T input_texel = texelFetch(t_input, input_elem_pos.pos, 0); + out_texel[comp] = input_texel[input_elem_pos.comp]; + + out_tidx.data[outp.packed_dim]++; + } + + imageStore(t_out, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml new file mode 100644 index 00000000000..dd38ecd0a7d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +gather_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: gather_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index abef2225cd9..6bf4c71a3c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -8,5 +8,6 @@ index_select: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index a306e3ce47d..716f7ecf2d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -8,5 +8,6 @@ index_select_channel: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh index c4feb17ef2e..b9ac0e5dace 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh @@ -147,6 +147,20 @@ struct TensorIndex4D { ivec4 data; }; +TensorIndex4D zero_tensor4d_idx() { + TensorIndex4D tidx; + tidx.data = ivec4(0); + return tidx; +} + +bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) { + return any(greaterThanEqual(tidx.data, meta.sizes[0])); +} + +bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) { + return any(greaterThanEqual(tidx.data, meta.sizes)); +} + // // TextureElementIndex // @@ -245,41 +259,86 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) { tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1); } -TensorIndex4D zero_tensor4d_idx() { - TensorIndex4D tidx; - tidx.data = ivec4(0); - return tidx; -} - -// Does not account for axis mapping or batches -TensorIndex4D texture_pos_to_tensor_idx_simple( - const ivec3 pos, const TextureMetadata meta) { +// Does not account for axis mapping +TensorIndex4D texture_pos_to_tensor4d_idx_simple( + const TextureMetadata meta, const ivec3 pos) { TensorIndex4D tidx; tidx.data.xyz = pos; tidx.data.w = 0; tidx.data[meta.packed_dim] *= 4; + + // Compute batch idx accounting for batch concatenation, assuming channels as + // the concatenation dim. + if (meta.sizes.w > 1) { + int channels = meta.sizes.z; + if (meta.packed_dim == 2) { + channels = align_up_4(channels); + } + tidx.data.w = tidx.data.z / channels; + tidx.data.z = tidx.data.z % channels; + } return tidx; } -// Does not account for axis mapping or batches -TextureElementIndex tensor_idx_to_texture_element_idx_simple( - const TensorIndex4D tidx, const TextureMetadata meta) { +// Does not account for axis mapping +ivec3 tensor4d_idx_to_texel_pos_simple( + const TextureMetadata meta, const TensorIndex4D tidx) { + ivec3 texel_pos; + + const int packed_dim_idx = tidx.data[meta.packed_dim]; + + texel_pos = tidx.data.xyz; + texel_pos[meta.packed_dim] = div_4(packed_dim_idx); + + // Account for batch concatenation, assuming channels as the concatenation dim + if (meta.sizes.w > 1) { + int channels_ntexels = meta.sizes.z; + if (meta.packed_dim == 2) { + channels_ntexels = div_up_4(channels_ntexels); + } + texel_pos.z += tidx.data.w * channels_ntexels; + } + + return texel_pos; +} + +// Does not account for axis mapping +TextureElementIndex tensor4d_idx_to_texture_element_idx_simple( + const TextureMetadata meta, const TensorIndex4D tidx) { const int packed_dim_idx = tidx.data[meta.packed_dim]; TextureElementIndex tex_idx; tex_idx.pos = tidx.data.xyz; tex_idx.pos[meta.packed_dim] = div_4(packed_dim_idx); tex_idx.comp = mod_4(packed_dim_idx); + + // Account for batch concatenation, assuming channels as the concatenation dim + if (meta.sizes.w > 1) { + int channels_ntexels = meta.sizes.z; + if (meta.packed_dim == 2) { + channels_ntexels = div_up_4(channels_ntexels); + } + tex_idx.pos.z += tidx.data.w * channels_ntexels; + } + return tex_idx; } +uint tensor4d_idx_to_linear_idx( + const BufferMetadata meta, + const TensorIndex4D tidx) { + uint lin_idx = 0; + for (int d = 0; d < 4; ++d) { + lin_idx += meta.strides[0][d] * tidx.data[d]; + } + return lin_idx; +} + // // Debug utilities // #ifdef DEBUG_MODE -#extension GL_EXT_debug_printf : enable - void printTensorIndex(const TensorIndex tidx) { debugPrintfEXT( "TensorIndex: tidx=[%u %u %u %u %u %u %u %u]\\n", @@ -288,13 +347,21 @@ void printTensorIndex(const TensorIndex tidx) { ); } -void printTensorIndex4D(const TensorIndex tidx) { +void printTensorIndex4D(const TensorIndex4D tidx) { debugPrintfEXT( "TensorIndex4D: [%u, %u, %u, %u]\\n", - tidx.data[0][0], tidx.data[0][1], tidx.data[0][2], tidx.data[0][3] + tidx.data[0], tidx.data[1], tidx.data[2], tidx.data[3] ); } +void printTextureElementIndex(const TextureElementIndex tex_idx) { + debugPrintfEXT( + "TextureElementIndex: pos=[%d %d %d] comp=%d\\n", + tex_idx.pos.x, tex_idx.pos.y, tex_idx.pos.z, tex_idx.comp + ); +} + + void printBufferMetadata(const BufferMetadata meta) { debugPrintfEXT( "BufferMetadata: ndim=%u numel=%u\\n sizes=[%u %u %u %u %u %u %u %u]\\n dim_order=[%u %u %u %u %u %u %u %u]\\n strides=[%u %u %u %u %u %u %u %u]\\n", diff --git a/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml index 02afc3846a2..91306bd4cbf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml @@ -8,5 +8,7 @@ pad_channel: DTYPE: - VALUE: float - VALUE: half + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: pad_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml b/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml index dd74ec9cc28..2eb57291bb2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml @@ -8,5 +8,7 @@ pad_height_width: DTYPE: - VALUE: float - VALUE: half + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: pad_height_width diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.glsl similarity index 100% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.glsl diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.yaml similarity index 58% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.yaml index 37721db1ba8..e453214bc1a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.yaml @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -quantize_and_pack_linear_input: +quantize_and_pack_4h4w: parameter_names_with_default_values: DTYPE: float OUTPUT_STORAGE: texture3d @@ -12,13 +12,14 @@ quantize_and_pack_linear_input: STORAGE: texture3d GRANULARITY: per_tensor generate_variant_forall: + combination: + parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE] + combos: + - parameter_values: [texture3d, texture3d] + - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: half - VALUE: float shader_variants: - - NAME: quantize_and_pack_linear_input_per_tensor_texture3d_texture3d - - NAME: quantize_and_pack_linear_input_per_tensor_buffer_texture3d - OUTPUT_STORAGE: buffer - - NAME: quantize_and_pack_linear_input_per_tensor_buffer_buffer - OUTPUT_STORAGE: buffer - INPUT_STORAGE: buffer + - NAME: quantize_and_pack_4h4w_per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl similarity index 100% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml similarity index 67% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml index 3fc66db2718..bdbc81c59d7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -quantize_and_pack_linear_input_with_sums: +quantize_and_pack_4h4w_with_group_sums: parameter_names_with_default_values: DTYPE: float OUTPUT_STORAGE: buffer @@ -16,14 +16,14 @@ quantize_and_pack_linear_input_with_sums: - VALUE: half - VALUE: float shader_variants: - - NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_texture3d - - NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_buffer + - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_texture3d + - NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_buffer OUTPUT_STORAGE: buffer INPUT_STORAGE: buffer - - NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_texture3d + - NAME: quantize_and_pack_4h4w_with_group_sums_o4w16_buffer_texture3d NUM_GROUPS_PER_WG: 4 NUM_WORKERS_PER_GROUP: 16 - - NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_buffer + - NAME: quantize_and_pack_4h4w_with_group_sums_o4w16_buffer_buffer NUM_GROUPS_PER_WG: 4 NUM_WORKERS_PER_GROUP: 16 OUTPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.glsl similarity index 98% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.glsl index d485523709b..dfa0b5a95bf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.glsl @@ -31,7 +31,7 @@ layout(std430) buffer; #include "conv2d_common.glslh" ${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE)} ${layout_declare_ubo(B, "ivec4", "input_sizes")} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml similarity index 83% rename from backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml rename to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml index 712d3156e2e..dd6cd527e16 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -quantize_and_pack_q8ta_conv2d_input: +quantize_and_pack_4w4c: parameter_names_with_default_values: DTYPE: float OUTPUT_STORAGE: texture3d @@ -15,7 +15,8 @@ quantize_and_pack_q8ta_conv2d_input: combos: - parameter_values: [texture3d, texture3d] - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float shader_variants: - - NAME: quantize_and_pack_q8ta_conv2d_input + - NAME: quantize_and_pack_4w4c_per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl deleted file mode 100644 index 7bf3a932c6c..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "int", "out_numel")} -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_out_sizes")} -${layout_declare_ubo(B, "ivec4", "t_out_strides")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "quantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); -const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); - -/* - Quantization Shader (Buffer Storage) - This shader converts floating-point tensor values to n-bit integer representations - using pre-computed quantization parameters (scale and zero_point). The quantization - maps floating-point values to a discrete integer range while preserving the original - data distribution as much as possible. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - quantize_per_tensor - This mode applies uniform quantization across the entire tensor using a single scale - and zero_point value. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_token - This mode applies quantization individually to each token (or element) in the input, - using separate scale and zero_point values for each token. For instance if we have - a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_channel - This mode applies quantization separately to each channel of the input tensor, using - distinct scale and zero_point values for each channel. For example, if the tensor shape - is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing - each channel to be quantized independently. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_block_wise - This mode applies quantization in blocks or groups of elements, allowing different scale - and zero_point values for each block. It is equivalent to quantize_affine, where quantization - parameters are affine transformations applied per block. For example, if the tensor shape - is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. - - (*) global_wg_size: default - (*) local_wg_size: default - - Quantization Formula: - qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). -*/ - -#ifdef per_tensor - -void quantize_per_tensor() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); - - t_out[out_bufi] = qvalue; -} - -#elif defined(per_token) - -void quantize_per_token() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - int token_idx = 0; - - if (t_out_sizes.w > 1) { - // 4D tensor - token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.z > 1) { - // 3D tensor - token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.y > 1) { - // 2D tensor - token_idx = out_tidx.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - OUT_T qvalue = quantize_val(value, float(t_scale[token_idx]), int(t_zero_point[token_idx])); - - t_out[out_bufi] = qvalue; -} - -#elif defined(per_channel) - -void quantize_per_channel() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - // Calculate channel index based on the quantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (tidx.x) - // axis 1 -> H dimension (tidx.y) - // axis 2 -> C dimension (tidx.z) - // axis 3 -> N dimension (tidx.w) - int channel_idx = 0; - - if (axis == 0) { - channel_idx = out_tidx.x; - } else if (axis == 1) { - channel_idx = out_tidx.y; - } else if (axis == 2) { - channel_idx = out_tidx.z; - } else if (axis == 3) { - channel_idx = out_tidx.w; - } - - channel_idx = min(channel_idx, num_channels - 1); - - OUT_T qvalue = quantize_val(value, float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); - - t_out[out_bufi] = qvalue; -} - -#else // block_wise - -void quantize_block_wise() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - const ivec4 bcoord = out_tidx / blockSize; - - const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - const OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); - - t_out[out_bufi] = qvalue; -} - -#endif - -void main() { - quantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml deleted file mode 100644 index fb5853ecd20..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ /dev/null @@ -1,31 +0,0 @@ -quantize_buffer: - parameter_names_with_default_values: - IN_DTYPE: float - OUT_DTYPE: int32 - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - OUT_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: quantize_per_tensor_buffer - MODE: per_tensor - - NAME: quantize_per_token_buffer - MODE: per_token - - NAME: quantize_per_channel_buffer - MODE: per_channel - - NAME: quantize_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl deleted file mode 100644 index 12e5769f50d..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} - -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict BlockPC { - ivec4 blockSize; // WHCN - ivec4 numBlocks; // (#W,#H,#C,#N) - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_out_limits")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "quantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - Quantization Shader (Texture Storage) - This shader converts floating-point tensor values to n-bit integer representations - using pre-computed quantization parameters (scale and zero_point). The quantization - maps floating-point values to a discrete integer range while preserving the original - data distribution as much as possible. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - quantize_per_tensor - This mode applies uniform quantization across the entire tensor using a single scale - and zero_point value. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_token - This mode applies quantization individually to each token (or element) in the input, - using separate scale and zero_point values for each token. For instance if we have - a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_channel - This mode applies quantization separately to each channel of the input tensor, using - distinct scale and zero_point values for each channel. For example, if the tensor shape - is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing - each channel to be quantized independently. - - (*) global_wg_size: default - (*) local_wg_size: Default with special handling for batch dimension. When quantizing along - the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, - uses standard workgroup size derived from global workgroup dimensions. - - - quantize_block_wise - This mode applies quantization in blocks or groups of elements, allowing different scale - and zero_point values for each block. It is equivalent to quantize_affine, where quantization - parameters are affine transformations applied per block. For example, if the tensor shape - is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. - - (*) global_wg_size: default - (*) local_wg_size: Default with special handling for batch dimension. When quantizing along - the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, - uses standard workgroup size derived from global workgroup dimensions. - - Quantization Formula: - qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). -*/ - -#ifdef per_tensor - -void quantize_per_tensor() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); - outtex[i] = qvalue; - } - write_texel(t_out, pos, outtex); -} - -#elif defined(per_token) - -void quantize_per_token() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - - int token_idx = 0; - ivec3 dims = t_in_limits; - - if (dims.z > 1) { - // 3D tensor - token_idx = pos.z * dims.y + pos.y; - } else if (dims.y > 1) { - // 2D tensor - token_idx = pos.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = float(t_scale[token_idx]); - int zero_point_val = int(t_zero_point[token_idx]); - - IVEC4_T outtex; - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - - write_texel(t_out, pos, outtex); -} - -#elif defined(per_channel) - -void quantize_per_channel() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - // Calculate channel index based on the quantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (pos.x for texture, but width-packed so pos.x * 4 + component) - // axis 1 -> H dimension (pos.y) - // axis 2 -> C dimension (pos.z / C), but for 4D tensors this includes batch-channel folding - // axis 3 -> N dimension (pos.z / N), but for 4D tensors this includes batch-channel folding - - if (axis == 0) { - // Width dimension - each texel component has different channel index - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - int channel_idx = pos.x * 4 + i; - channel_idx = min(channel_idx, num_channels - 1); - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 1) { - // Height dimension - all texel components use same channel index - int channel_idx = pos.y; - channel_idx = min(channel_idx, num_channels - 1); - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 2) { - // Channel dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - int channel_idx = folded_idx % num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 3) { - // Batch dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual batch index from the folded dimension - int folded_idx = pos.z; - int batch_idx = folded_idx / num_channels; - - float scale_val = float(t_scale[batch_idx]); - int zero_point_val = int(t_zero_point[batch_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } - - write_texel(t_out, pos, outtex); -} - -#else // block_wise - -void quantize_block_wise() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) - return; - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); - int foldedZ = pos.z; - - int C_total = numBlocks.z * blockSize.z; - - [[unroll]] for (int i = 0; i < 4; ++i) { - ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); - - ivec4 bcoord = tidx / blockSize; - int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); - outtex[i] = qvalue; - } - - write_texel(t_out, pos, outtex); -} - -#endif - -void main() { - quantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml deleted file mode 100644 index 03d418ff2f7..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ /dev/null @@ -1,31 +0,0 @@ -quantize_texture: - parameter_names_with_default_values: - IN_DTYPE: float - OUT_DTYPE: int32 - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - OUT_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: quantize_per_tensor_texture3d - MODE: per_tensor - - NAME: quantize_per_token_texture3d - MODE: per_token - - NAME: quantize_per_channel_texture3d - MODE: per_channel - - NAME: quantize_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl index 5d377cf1f1a..af5f5f661e7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl @@ -57,8 +57,6 @@ $else: $if OUTPUT_IS_INDICES: #define OUTPUT_IS_INDICES -#extension GL_EXT_debug_printf : require - void main() { const uint out_bufi = gl_GlobalInvocationID.y; diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml index 4147e82965a..c48237f7568 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml @@ -6,5 +6,7 @@ repeat_channel: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: repeat_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml index 5c284a580c9..f56172dc7f0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml @@ -6,5 +6,7 @@ repeat_interleave: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: repeat_interleave diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl index 30375728921..155eda467c4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl @@ -13,23 +13,29 @@ #define VEC4_T ${texel_load_type(DTYPE, STORAGE)} ${define_required_extensions(DTYPE)} +${define_active_storage_type(STORAGE)} layout(std430) buffer; -${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)} -${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", "xqout_limits")} -${layout_declare_ubo(B, "ivec3", "xkout_limits")} +#include "indexing.glslh" -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)} -layout(constant_id = 3) const int packed_dim = 0; +$if STORAGE == "buffer": + ${layout_declare_ubo(B, "BufferMetadata", "xqout")} + ${layout_declare_ubo(B, "BufferMetadata", "xkout")} + ${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")} +$else: + ${layout_declare_ubo(B, "TextureMetadata", "xqout")} + ${layout_declare_ubo(B, "TextureMetadata", "xkout")} + ${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")} -#include "indexing_utils.h" +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* * This shader computes rotary positional embeddings which are used in the Llama @@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0; * 1. xq (batch_size, sequence_len, num_heads, head_dim) * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim) * 3. freqs_cos (sequence_len, head_dim / 2) - * 4. freqs_cos (sequence_len, head_dim / 2) + * 4. freqs_sin (sequence_len, head_dim / 2) * * Two output tensors are produced, with the same shapes as xq and xk * respectively. @@ -66,23 +72,43 @@ void main() { // Each thread will write to two output locations to maximize data re-use. // One texel loaded from the freqs_cos/freqs_sin tensors can be used to // calculate two output texels. - const ivec3 x_pos_1 = ivec3( - gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz); - const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz); + TensorIndex4D out_tidx_1 = zero_tensor4d_idx(); + out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8; + out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz); + + TensorIndex4D out_tidx_2 = out_tidx_1; + out_tidx_2.data.x += 4; - if (any(greaterThanEqual(x_pos_2, xqout_limits))) { + if (out_of_bounds(out_tidx_2, xqout)) { return; } - const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0); + TensorIndex4D freqs_tidx = zero_tensor4d_idx(); + freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4; + freqs_tidx.data.y = out_tidx_1.data.z; - VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos); - VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos); +#ifdef USING_BUFFER + const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx)); + VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi]; + VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi]; - // Compute xqout + uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1)); + uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2)); + VEC4_T x_tex_1 = t_xq[x_texel_bufi_1]; + VEC4_T x_tex_2 = t_xq[x_texel_bufi_2]; + +#else // USING_TEXTURE + const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx); + VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0); + VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0); - VEC4_T x_tex_1 = load_texel(xq, x_pos_1); - VEC4_T x_tex_2 = load_texel(xq, x_pos_2); + const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1); + const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2); + VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0); + VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0); +#endif + + // Compute xqout // Separate into even and odd elements VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); @@ -94,20 +120,34 @@ void main() { VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); - write_texel(xqout, x_pos_1, xout_tex_1); - write_texel(xqout, x_pos_2, xout_tex_2); +#ifdef USING_BUFFER + t_xqout[x_texel_bufi_1] = xout_tex_1; + t_xqout[x_texel_bufi_2] = xout_tex_2; +#else // USING_TEXTURE + imageStore(t_xqout, x_pos_1, xout_tex_1); + imageStore(t_xqout, x_pos_2, xout_tex_2); +#endif // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout // may have a larger height dim than xk and xkout. Only compute xkout if this // invocation is still within bounds. - if (any(greaterThanEqual(x_pos_2, xkout_limits))) { + if (out_of_bounds(out_tidx_2, xkout)) { return; } // Compute xkout - x_tex_1 = load_texel(xk, x_pos_1); - x_tex_2 = load_texel(xk, x_pos_2); +#ifdef USING_BUFFER + x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1)); + x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2)); + + x_tex_1 = t_xk[x_texel_bufi_1]; + x_tex_2 = t_xk[x_texel_bufi_2]; + +#else // USING_TEXTURE + x_tex_1 = texelFetch(t_xk, x_pos_1, 0); + x_tex_2 = texelFetch(t_xk, x_pos_2, 0); +#endif x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw); @@ -118,6 +158,11 @@ void main() { xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); - write_texel(xkout, x_pos_1, xout_tex_1); - write_texel(xkout, x_pos_2, xout_tex_2); +#ifdef USING_BUFFER + t_xkout[x_texel_bufi_1] = xout_tex_1; + t_xkout[x_texel_bufi_2] = xout_tex_2; +#else // USING_TEXTURE + imageStore(t_xkout, x_pos_1, xout_tex_1); + imageStore(t_xkout, x_pos_2, xout_tex_2); +#endif } diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml index a81fd564d10..ba8aa400958 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml @@ -3,6 +3,9 @@ rotary_embedding: DTYPE: float STORAGE: texture3d generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer DTYPE: - VALUE: half - VALUE: float diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl index b780cdce6fe..5f7e4c2719d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -14,10 +14,6 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -#define DEBUG_MODE - -#extension GL_EXT_debug_printf : enable - #include "common.glslh" ${layout_declare_tensor(B, "w", "t_cache", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} diff --git a/backends/vulkan/runtime/graph/ops/glsl/select.glslh b/backends/vulkan/runtime/graph/ops/glsl/select.glslh index 6509015b4b6..5390e2a4bb2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/select.glslh @@ -9,70 +9,87 @@ #ifndef SELECT_GLSLH #define SELECT_GLSLH -#ifndef USING_BUFFER +#ifdef USING_BUFFER /* - * Enable the fast path if a texel loaded from the input texture can be used as - * is to store to the output texture. The following conditions must be met: + * Converts output tensor indices to input tensor indices for the select operation + * on buffer storage. * - * 1. The input and output textures have the same packed dimension. - * 2. The selected_dim must not be the packed dimension of the input. - * 3. The packed dimension of the input must "map" to the packed dimension of - * the output. This occurs if selected_dim is greater than the packed dimension - * of the input. + * This is done by "inserting" the select index at the selected_dim in the input + * tensor index. + * + * Parameters assumed to be defined: + * - inp: BufferMetadata + * - selected_dim + * - index */ -bool can_use_fast_path() { - if (out_packed_dim != in_packed_dim) { - return false; +TensorIndex out_tidx_to_in_tidx(const TensorIndex out_tidx) { + TensorIndex in_tidx; + initialize(in_tidx); + + int in_size = int(size_at(inp, selected_dim)); + int adjusted_index = index; + if (index < 0) { + adjusted_index = index + in_size; } - if (selected_dim <= in_packed_dim) { - return false; + + // Copy indices before selected_dim + for (int d = 0; d < selected_dim; d++) { + in_tidx.data[div_4(d)][mod_4(d)] = idx_at(out_tidx, d); } - return true; + + // Insert the selected index + in_tidx.data[div_4(selected_dim)][mod_4(selected_dim)] = adjusted_index; + + // Copy indices after selected_dim (shifted by 1) + for (int d = selected_dim; d < int_ndim(inp) - 1; d++) { + in_tidx.data[div_4(d + 1)][mod_4(d + 1)] = idx_at(out_tidx, d); + } + + return in_tidx; } -#endif // USING_BUFFER +#else // texture storage /* - * Given an output tensor index, return the corresponding input tensor index for - * the select operator. This is done by "inserting" the select index at the - * selected_dim in the input tensor index. + * Converts output tensor indices to input tensor indices for the select operation + * on texture storage. * - * A simple example is (note all tensor index are in WHCN order): - * out_tidx = [7, 5, 9] - * selected_dim = 2 - * index = 3 - * in_tidx = [7, 3, 5, 9] + * This is done by "inserting" the select index at the selected_dim in the input + * tensor index. * - * This function assumes that the following variables are defined in the layout: - * - in_sizes + * Parameters assumed to be defined: + * - inp: TextureMetadata * - selected_dim * - index */ -ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { - ivec4 in_tidx = ivec4(0); +TensorIndex4D out_tidx_to_in_tidx(const TensorIndex4D out_tidx) { + TensorIndex4D in_tidx; + in_tidx.data = ivec4(0); int adjusted_index = index; if (index < 0) { - adjusted_index = index + in_sizes[selected_dim]; + adjusted_index = index + inp.sizes[selected_dim]; } // Handle different dimensions for selection if (selected_dim == 0) { // Select from width dimension - in_tidx = ivec4(adjusted_index, out_tidx.x, out_tidx.y, out_tidx.z); + in_tidx.data = ivec4(adjusted_index, out_tidx.data.x, out_tidx.data.y, out_tidx.data.z); } else if (selected_dim == 1) { // Select from height dimension - in_tidx = ivec4(out_tidx.x, adjusted_index, out_tidx.y, out_tidx.z); + in_tidx.data = ivec4(out_tidx.data.x, adjusted_index, out_tidx.data.y, out_tidx.data.z); } else if (selected_dim == 2) { // Select from channel dimension - in_tidx = ivec4(out_tidx.x, out_tidx.y, adjusted_index, out_tidx.z); + in_tidx.data = ivec4(out_tidx.data.x, out_tidx.data.y, adjusted_index, out_tidx.data.z); } else if (selected_dim == 3) { // Select from batch dimension - in_tidx = ivec4(out_tidx.x, out_tidx.y, out_tidx.z, adjusted_index); + in_tidx.data = ivec4(out_tidx.data.x, out_tidx.data.y, out_tidx.data.z, adjusted_index); } return in_tidx; } +#endif // USING_BUFFER + #endif // SELECT_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice.glslh b/backends/vulkan/runtime/graph/ops/glsl/slice.glslh index 87325754f4d..0a815c85d66 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/slice.glslh @@ -9,49 +9,61 @@ #ifndef SLICE_GLSLH #define SLICE_GLSLH -#ifndef USING_BUFFER +#include "indexing.glslh" -/** - * Enable the fast path if a texel loaded from the input texture can be used as - * is to store to the output texture. The following conditions must be met: +#ifdef USING_BUFFER + +/* + * Converts output tensor indices to input tensor indices for the slice operation + * on buffer storage. * - * 1. The input and output textures have the same packed dimension. - * 2. The select_dim must not be the packed dimension of the input. + * Parameters assumed to be defined: + * - inp: BufferMetadata + * - selected_dim + * - start + * - step */ -bool can_use_fast_path() { - if (out_packed_dim != in_packed_dim) { - return false; - } - if (in_packed_dim == selected_dim) { - return false; +TensorIndex out_tidx_to_in_tidx(const TensorIndex out_tidx) { + TensorIndex in_tidx = out_tidx; + + int in_size = int(size_at(inp, selected_dim)); + int adjusted_start = start; + if (start < 0) { + adjusted_start = start + in_size; } - return true; + + uint out_idx = idx_at(out_tidx, selected_dim); + in_tidx.data[div_4(selected_dim)][mod_4(selected_dim)] = + adjusted_start + int(out_idx) * step; + + return in_tidx; } -#endif // USING_BUFFER +#else // texture storage /* - * Converts output tensor indices to input tensor indices for the slice operation. - * This function maps the output indices to the corresponding input indices based on - * the slice parameters (start, step, selected_dim). + * Converts output tensor indices to input tensor indices for the slice operation + * on texture storage. * - * Parameters assumed to be defined in the layout specifier: - * - in_sizes + * Parameters assumed to be defined: + * - inp: TextureMetadata * - selected_dim * - start * - step */ -ivec4 out_tidx_to_in_tidx(const ivec4 out_tidx) { - ivec4 in_tidx = out_tidx; +TensorIndex4D out_tidx_to_in_tidx(const TensorIndex4D out_tidx) { + TensorIndex4D in_tidx = out_tidx; int adjusted_start = start; if (start < 0) { - adjusted_start = start + in_sizes[selected_dim]; + adjusted_start = start + inp.sizes[selected_dim]; } - in_tidx[selected_dim] = adjusted_start + out_tidx[selected_dim] * step; + in_tidx.data[selected_dim] = adjusted_start + out_tidx.data[selected_dim] * step; return in_tidx; } +#endif // USING_BUFFER + #endif // SLICE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.glsl new file mode 100644 index 00000000000..0505c9e7bcd --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.glsl @@ -0,0 +1,50 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int split_dim = 0; +layout(constant_id = 4) const int split_idx = 0; +layout(constant_id = 5) const int split_offset = 0; + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + + TensorIndex input_tidx = out_tidx; + input_tidx.data[div_4(split_dim)][mod_4(split_dim)] += split_offset; + + const uint input_bufi = tensor_idx_to_linear_idx(inp, input_tidx); + + t_out[out_bufi] = t_input[input_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.yaml new file mode 100644 index 00000000000..45dbff832f9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_buffer.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +split_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: split_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/split_texture.glsl new file mode 100644 index 00000000000..92d7ce548e2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_texture.glsl @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "texture3d")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int split_dim = 0; +layout(constant_id = 4) const int split_idx = 0; +layout(constant_id = 5) const int split_offset = 0; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + + VEC4_T out_texel = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + + TensorIndex4D input_tidx = out_tidx; + input_tidx.data[split_dim] += split_offset; + + for (int comp = 0; comp < limit; comp++) { + TextureElementIndex input_elem_pos = tensor4d_idx_to_texture_element_idx_simple( + inp, input_tidx); + + VEC4_T input_texel = texelFetch(t_input, input_elem_pos.pos, 0); + out_texel[comp] = input_texel[input_elem_pos.comp]; + + input_tidx.data[outp.packed_dim]++; + } + + imageStore(t_output, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/split_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/split_texture.yaml new file mode 100644 index 00000000000..6a1613a401e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/split_texture.yaml @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +split_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: split_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl index 7605c59c72f..73b753ccc0b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_buffer.glsl @@ -11,18 +11,23 @@ #define PRECISION ${PRECISION} #define UBO_PARAMS ${UBO_PARAMS} -#define VEC4_T ${texel_type(DTYPE)} #define T ${buffer_scalar_type(DTYPE)} ${define_active_storage_type("buffer")} ${define_required_extensions(DTYPE)} +#extension GL_EXT_control_flow_attributes : require + layout(std430) buffer; -#include "indexing_utils.h" +#include "indexing.glslh" + ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + $if UBO_PARAMS: $if OP_NAME == "slice": ${layout_declare_ubo(B, "int", "start")} @@ -32,10 +37,6 @@ $if UBO_PARAMS: ${layout_declare_ubo(B, "int", "index")} layout(push_constant) uniform restrict Block { - ivec4 in_sizes; - ivec4 out_strides; - ivec4 in_strides; - int out_numel; int selected_dim; $if not UBO_PARAMS: $if OP_NAME == "slice": @@ -46,24 +47,19 @@ layout(push_constant) uniform restrict Block { int index; }; -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); - layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "${OP_NAME}.glslh" void main() { - const int out_bufi = ivec3(gl_GlobalInvocationID).x; - if (out_bufi >= out_numel) { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { return; } - const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order); - ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + TensorIndex in_tidx = out_tidx_to_in_tidx(out_tidx); - const int in_bufi = tidx_to_bufi(in_tidx, in_strides); + const uint in_bufi = tensor_idx_to_linear_idx(inp, in_tidx); t_out[out_bufi] = t_in[in_bufi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl index 0f34713cb43..d2c9c025242 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/transfer_texture.glsl @@ -11,19 +11,25 @@ #define PRECISION ${PRECISION} #define UBO_PARAMS ${UBO_PARAMS} -#define VEC4_T ${texel_type(DTYPE)} -#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} ${define_active_storage_type("texture3d")} ${define_required_extensions(DTYPE)} +#extension GL_EXT_control_flow_attributes : require + layout(std430) buffer; -#include "indexing_utils.h" +#include "common.glslh" +#include "indexing.glslh" ${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} ${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + $if UBO_PARAMS: $if OP_NAME == "slice": ${layout_declare_ubo(B, "int", "start")} @@ -33,8 +39,6 @@ $if UBO_PARAMS: ${layout_declare_ubo(B, "int", "index")} layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 in_sizes; int selected_dim; $if not UBO_PARAMS: $if OP_NAME == "slice": @@ -45,48 +49,33 @@ layout(push_constant) uniform restrict Block { int index; }; -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); -const lowp int out_packed_dim = unhash_packed_dim(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); -const lowp int in_packed_dim = unhash_packed_dim(in_layout); - layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #include "${OP_NAME}.glslh" void main() { - const ivec3 lpos = ivec3(gl_GlobalInvocationID); - ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim); + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(out_tidx, out_sizes))) { + if (out_of_bounds(out_pos, outp)) { return; } - if (can_use_fast_path()) { - ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); - ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); - VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + VEC4_T out_texel = VEC4_T(0); - write_texel_lpos(t_out, lpos, in_texel, out_axis_map); - } - else { - VEC4_T out_texel = VEC4_T(0); - for (int texel_i = 0; texel_i < 4; ++texel_i) { - ivec4 in_tidx = out_tidx_to_in_tidx(out_tidx); - ivec3 in_pos = tidx_to_pos(in_tidx, in_sizes, in_axis_map, in_packed_dim); - int element_idx = in_tidx[in_packed_dim] % 4; - - VEC4_T in_texel = VEC4_T(load_texel(t_in, in_pos)); - T selected_value = T(in_texel[element_idx]); + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < limit; comp++) { + TensorIndex4D in_tidx = out_tidx_to_in_tidx(out_tidx); - out_texel[texel_i] = selected_value; + TextureElementIndex in_elem_pos = tensor4d_idx_to_texture_element_idx_simple( + inp, in_tidx); - out_tidx[out_packed_dim]++; - } + VEC4_T in_texel = texelFetch(t_in, in_elem_pos.pos, 0); + out_texel[comp] = in_texel[in_elem_pos.comp]; - write_texel_lpos(t_out, lpos, out_texel, out_axis_map); + out_tidx.data[outp.packed_dim]++; } + + imageStore(t_out, out_pos, out_texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.glsl similarity index 86% rename from backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl rename to backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.glsl index 798366b523a..be0a39bac3c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.glsl @@ -30,7 +30,7 @@ layout(std430) buffer; #include "conv2d_common.glslh" -${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE)} ${layout_declare_tensor(B, "r", "t_packed_int8_output", "int", INPUT_STORAGE, is_scalar_array=False)} ${layout_declare_ubo(B, "ivec4", "output_sizes")} @@ -84,7 +84,19 @@ void unpack_and_dequantize( void store_fp_output_texel( const Conv2dTensorIndex tidx, const VEC4_T out_texel) { +#ifdef OUTPUT_BUFFER + const int c_idx = mul_4(tidx.data.z); + const int c_stride = output_sizes.y * output_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * output_sizes.x + tidx.data.x; + const int limit = min(output_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; ++i) { + t_fp_output[base_buf_i + i * c_stride] = out_texel[i]; + } +#else imageStore(t_fp_output, tidx.data, out_texel); +#endif } void store_fp_tile( @@ -92,7 +104,9 @@ void store_fp_tile( const Conv2dBlockIndex block_idx) { Conv2dTensorIndex store_tidx = block_idx_to_tensor_idx(block_idx); [[unroll]] for (int w = 0; w < 4; w++) { - store_fp_output_texel(store_tidx, block.data[w][0]); + if (store_tidx.data.x < output_sizes.x) { + store_fp_output_texel(store_tidx, block.data[w][0]); + } store_tidx.data.x++; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml similarity index 82% rename from backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml rename to backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml index 24b253da343..0a419e632e3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -unpack_and_dequantize_q8ta_conv2d_output: +unpack_4w4c_and_dequantize: parameter_names_with_default_values: DTYPE: float OUTPUT_STORAGE: texture3d @@ -15,7 +15,8 @@ unpack_and_dequantize_q8ta_conv2d_output: combos: - parameter_values: [texture3d, texture3d] - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float shader_variants: - - NAME: unpack_and_dequantize_q8ta_conv2d_output + - NAME: unpack_4w4c_and_dequantize_per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl index 2c02803a9b1..96b9aa85a1f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl @@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + /* * The insight behind the view operation is that the contiguous index of each * tensor element in the input and output tensors are the same. @@ -28,17 +30,20 @@ void main() { return; } - TensorIndex outp_tidx; - linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + uint inp_bufi = outp_bufi; + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); - // To map the output to the input, find the input element that has the same - // contiguous index as the output element. - const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. + const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); - TensorIndex inp_tidx; - contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); - const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } t_outp[outp_bufi] = t_inp[inp_bufi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl new file mode 100644 index 00000000000..a926c9fea11 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl @@ -0,0 +1,54 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)} +${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + +/* + * The insight behind the view_convert operation is that the contiguous index of each + * tensor element in the input and output tensors are the same, but the data types + * may be different and need conversion. + */ +void main() { + const uint outp_bufi = gl_GlobalInvocationID.x; + if (outp_bufi >= numel(outp)) { + return; + } + + uint inp_bufi = outp_bufi; + + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. + const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } + + // Convert data type from input to output + t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml new file mode 100644 index 00000000000..11d56cad4a9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +view_convert_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: buffer + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [int32, half] + - parameter_values: [uint8, float] + - parameter_values: [uint8, half] + - parameter_values: [uint8, int32] + shader_variants: + - NAME: view_convert_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index a4a96ffdb88..a36660e0aca 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -34,150 +34,6 @@ void resize_choose_qparams_per_row( graph->virtual_resize(input_zeros, new_sizes); } -utils::uvec3 choose_qparams_pick_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - // For per-tensor quantization, we want a single workgroup that can handle - // all elements with proper reduction. The shader uses NWORKERS=64 threads. - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For buffer storage, use a single workgroup in X dimension - // The shader will handle strided access across all elements - return {1u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_global_wg_size(args.at(0).refs.at(0)); - } -} - -utils::uvec3 choose_qparams_pick_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For buffer storage, use 64 threads in X dimension to match NWORKERS - // This ensures the shared memory arrays are properly sized - return {64u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_local_wg_size(global_workgroup_size); - } -} - -utils::uvec3 choose_qparams_per_token_pick_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For per-token quantization, we need one workgroup per token - // Calculate number of tokens (product of all dimensions except the last - // one) - const auto input_sizes = graph->sizes_of(input); - int64_t num_tokens = 1; - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - return {static_cast(num_tokens), 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_global_wg_size(args.at(0).refs.at(0)); - } -} - -utils::uvec3 choose_qparams_per_token_pick_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - return {1u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_local_wg_size(global_workgroup_size); - } -} - -utils::uvec3 choose_qparams_block_wise_pick_global_wg_size( - ComputeGraph* g, - const vkapi::ShaderInfo&, - const std::vector& a, - const std::vector& r) { - const ValueRef input = a.at(2).refs.at(0); - const auto blkRef = r.at(0); - const auto inSz = g->sizes_of(input); - const auto blkList = g->get_int_list(blkRef); - - // Use same code as in add_choose_qparams_block_wise_node - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*blkList); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(inSz); - - // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) - utils::ivec4 nBlk = { - (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], - (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], - (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], - (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; - - uint32_t nBlocks = nBlk[0] * nBlk[1] * nBlk[2] * nBlk[3]; - - // For texture storage, use more threads to better utilize GPU parallelism - // Each thread can process multiple blocks with stride - if (g->is_buffer_storage(input)) { - return {nBlocks, 1u, 1u}; - } else { - // For texture storage, use more workgroups to better utilize GPU - // Aim for ~64-256 threads per workgroup for good occupancy - uint32_t preferred_threads_per_wg = 64; - uint32_t num_workgroups = - (nBlocks + preferred_threads_per_wg - 1) / preferred_threads_per_wg; - num_workgroups = std::max(1u, std::min(num_workgroups, nBlocks)); - return {num_workgroups * preferred_threads_per_wg, 1u, 1u}; - } -} - -utils::uvec3 choose_qparams_block_wise_pick_local_wg_size( - ComputeGraph* g, - const vkapi::ShaderInfo&, - const utils::uvec3& global_wg_size, - const std::vector& a, - const std::vector&) { - const ValueRef input = a.at(2).refs.at(0); - - if (g->is_buffer_storage(input)) { - return {1u, 1u, 1u}; - } else { - // For texture storage, use 64 threads per workgroup for better occupancy - uint32_t local_size = std::min(64u, global_wg_size[0]); - return {local_size, 1u, 1u}; - } -} - vkapi::ShaderInfo pick_choose_qparams_per_row_shader( ComputeGraph* graph, const std::vector& args, @@ -222,160 +78,6 @@ utils::uvec3 pick_choose_qparams_per_row_local_wg_size( return {workers_per_output, outputs_per_wg, 1u}; } -void add_choose_qparams_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& eps, - const ValueRef& scale_out, - const ValueRef& zero_point_out) { - std::string kernel_name("choose_qparams_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(zero_point_out)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - float eps_val = static_cast(graph.get_double(eps)); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zero_point_out), - graph.strides_ubo(zero_point_out)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(scale_out), - graph.logical_limits_ubo(zero_point_out)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - PushConstantDataInfo(&eps_val, sizeof(float)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_pick_global_wg_size, - choose_qparams_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zero_point_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - -void add_choose_qparams_per_token_asymmetric_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale_out, - const ValueRef& zero_point_out) { - std::string kernel_name("choose_qparams_per_token_asymmetric"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - int num_tokens_val = static_cast(num_tokens); - int quant_min_val = -128; // Fixed for asymmetric quantization - int quant_max_val = 127; // Fixed for asymmetric quantization - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zero_point_out), - graph.strides_ubo(zero_point_out)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(scale_out), - graph.logical_limits_ubo(zero_point_out)}; - } - - push_constants = { - PushConstantDataInfo(&num_tokens_val, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_per_token_pick_global_wg_size, - choose_qparams_per_token_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zero_point_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - void add_choose_qparams_per_row_node( ComputeGraph& graph, const ValueRef& input, @@ -427,221 +129,6 @@ void add_choose_qparams_per_row_node( resize_choose_qparams_per_row)); } -void add_choose_qparams_block_wise_node( - ComputeGraph& graph, - ValueRef input, - ValueRef block_size, - int mapping_type, // 0 / 1 / 2 - ValueRef quant_min, - ValueRef quant_max, - ValueRef eps, - ValueRef scale_out, - ValueRef zp_out) { - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // For shader compatibility, we still need to convert to WHCN order - // but the output shape calculation is now handled correctly in resize - // function - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) - utils::ivec4 num_blocks_vec = { - (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], - (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], - (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], - (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - // Handle optional quant_min and quant_max parameters - int qmin, qmax; - if (graph.val_is_none(quant_min) || graph.val_is_none(quant_max)) { - // Use default values based on target_dtype (similar to - // _get_and_check_qmin_qmax) For now, assume int8 range as default - this - // should match the Python implementation - qmin = -128; - qmax = 127; - } else { - qmin = static_cast(graph.get_int(quant_min)); - qmax = static_cast(graph.get_int(quant_max)); - } - - float eps_val; - if (graph.val_is_none(eps)) { - // Use default eps value (similar to Python implementation) - eps_val = 1.192092896e-07f; // torch.finfo(torch.float32).eps - } else { - eps_val = static_cast(graph.get_double(eps)); - } - - // Create push constants vector - std::vector push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&mapping_type, sizeof(int)), - PushConstantDataInfo(&qmin, sizeof(int)), - PushConstantDataInfo(&qmax, sizeof(int)), - PushConstantDataInfo(&eps_val, sizeof(float))}; - - std::string kernel_name("choose_qparams_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zp_out)); - - vkapi::ParamsBindList param_ubos; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zp_out), - graph.strides_ubo(zp_out)}; - } else { - // For texture input, the shader uses buffer storage for outputs - // so we need buffer UBOs for the output tensors - param_ubos = { - graph.logical_limits_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zp_out), - graph.strides_ubo(zp_out)}; - } - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_block_wise_pick_global_wg_size, - choose_qparams_block_wise_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zp_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {block_size}, - // Resizing Logic - nullptr)); -} - -void choose_qparams_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef eps = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef out_tuple_ref = args[arg_idx++]; - - ValueRef scale_out = kDummyValueRef; - ValueRef zero_point_out = kDummyValueRef; - - { - const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - scale_out = out_tuple->at(0); - zero_point_out = out_tuple->at(1); - } - - // Void the unused dtype parameter to match ATen signature - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point output dtypes - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); -} - -void choose_qparams_per_token_asymmetric_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef out_tuple_ref = args[arg_idx++]; - - ValueRef scale_out = kDummyValueRef; - ValueRef zero_point_out = kDummyValueRef; - - { - const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - scale_out = out_tuple->at(0); - zero_point_out = out_tuple->at(1); - } - - // Void the unused parameter to match ATen signature - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point output dtypes - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - add_choose_qparams_per_token_asymmetric_node( - graph, input, scale_out, zero_point_out); -} - bool can_use_choose_qparams_per_row( ComputeGraph& graph, const ValueRef input, @@ -674,11 +161,13 @@ void choose_qparams_affine_impl( int arg_idx = 0; const ValueRef input = args[arg_idx++]; const ValueRef mapping_type = args[arg_idx++]; + (void)mapping_type; const ValueRef block_size = args[arg_idx++]; const ValueRef target_dtype = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; const ValueRef eps = args[arg_idx++]; + (void)eps; const ValueRef scale_dtype = args[arg_idx++]; const ValueRef zero_point_dtype = args[arg_idx++]; const ValueRef out_tuple_ref = args[arg_idx++]; @@ -704,59 +193,7 @@ void choose_qparams_affine_impl( graph, input, quant_min, quant_max, scale_out, zero_point_out); } - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point dtypes from arguments - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - std::string mapping_type_str = graph.get_string(mapping_type); - int mapping_type_val = 0; // Default to ASYMMETRIC - - if (mapping_type_str == "ASYMMETRIC" || mapping_type_str.empty()) { - mapping_type_val = 0; // ASYMMETRIC - } else if (mapping_type_str == "SYMMETRIC") { - mapping_type_val = 1; - } else if (mapping_type_str == "SYMMETRIC_NO_CLIPPING_ERR") { - mapping_type_val = 2; - } else { - VK_THROW("Unsupported mapping_type: ", mapping_type_str); - } - - add_choose_qparams_block_wise_node( - graph, - input, - block_size, - mapping_type_val, - quant_min, - quant_max, - eps, - scale_out, - zero_point_out); + VK_THROW("Unsupported input case for choose_qparams_affine"); } void choose_qparams_per_row( @@ -769,27 +206,11 @@ void choose_qparams_per_row( const ValueRef input_scales = args[arg_idx++]; const ValueRef input_zps = args[arg_idx++]; - // ValueRef scale_out = kDummyValueRef; - // ValueRef zero_point_out = kDummyValueRef; - // - // { - // const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - // scale_out = out_tuple->at(0); - // zero_point_out = out_tuple->at(1); - // } - // - add_choose_qparams_per_row_node( graph, input, quant_min, quant_max, input_scales, input_zps); } REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.choose_qparams_per_token_asymmetric.default, - choose_qparams_per_token_asymmetric_impl); - // Register the per-channel quantization operator VK_REGISTER_OP(etvk.choose_qparams_per_row.default, choose_qparams_per_row); diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp deleted file mode 100644 index bd648dbae2d..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include -#include -#include -#include - -namespace vkcompute { - -using utils::ivec3; -using utils::ivec4; -using utils::uvec3; - -void add_copy_offset_node( - ComputeGraph& graph, - const ValueRef in, - const ivec3& range, - const ivec4& src_offset, - const ivec4& dst_offset, - const ValueRef out, - bool calc_out_pos_using_src_chnl, - bool calc_in_pos_using_dst_chnl) { - std::string kernel_name = "copy_offset"; - kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - - auto shader = VK_KERNEL_FROM_STR(kernel_name); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - { - {out, vkapi::kWrite}, - {in, vkapi::kRead}, - }, - // Parameter buffers - {}, - // Push Constants - { - PushConstantDataInfo(&range, sizeof(range), sizeof(ivec4)), - PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)), - PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)), - }, - // Specialization Constants - {graph.hashed_layout_of(out), - graph.hashed_layout_of(in), - (calc_out_pos_using_src_chnl ? 1 - : calc_in_pos_using_dst_chnl ? 2 - : 0)}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - -void add_copy_packed_dim_offset_node( - ComputeGraph& graph, - const ValueRef in, - const ivec3& range, - const ivec4& src_offset, - const ivec4& dst_offset, - const ValueRef out) { - // Check the packed dimension is same for both tensors, also check if the - // packed dimension is Width or Height. Since the function does not support - // channel packing. - VK_CHECK_COND( - graph.packed_dim_of(in) == graph.packed_dim_of(out) && - (graph.packed_dim_of(in) == WHCN::kWidthDim || - graph.packed_dim_of(in) == WHCN::kHeightDim)); - - std::string kernel_name = "copy_packed_dim_offset"; - kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - const std::vector in_sizes = graph.sizes_of(in); - const std::vector out_sizes = graph.sizes_of(out); - - // A copy of range with the last element set to batch size of the input tensor - ivec4 final_range = { - range[0], range[1], range[2], dim_at(in_sizes, kBatch4D)}; - ivec3 global_wg_size = graph.logical_limits_of(out); - - const auto packed_dim = graph.packed_dim_of(in); - // The starting offset in a texel where this tensor will start copying from - const auto src_lane_offset = src_offset[packed_dim] & 0x3; - // The starting offset in a texel where this tensor will start copying to - const auto dst_lane_offset = dst_offset[packed_dim] & 0x3; - - // The total packed texels this tensor will be copied from - // The first texel of tensor data in packed dimension will be copied from - // remaining lanes from current source Hence (4 - src_lane_offset) is added - // to tensor size in packed dimension - const auto src_packed_size = utils::div_up_4( - (4 - src_lane_offset) + utils::val_at(-packed_dim, out_sizes)); - - // The total packed texels this tensor will be copied to - // The first texel of tensor data in packed dimension will be copied to - // remaining lanes from previous write Hence (4 - dst_lane_offset) is added - // to tensor size in packed dimension - const auto dst_packed_size = utils::div_up_4( - (4 - dst_lane_offset) + utils::val_at(-packed_dim, in_sizes)); - - // If the starting src offset is not 0, and the total packed texels is - // greater than the source texel range - const bool has_additional_src_work = - src_lane_offset != 0 && src_packed_size > final_range[packed_dim]; - // If the starting dst offset is not 0, and the total packed texels is - // greater than the source texel range - const bool has_additional_dst_work = - dst_lane_offset != 0 && dst_packed_size > final_range[packed_dim]; - - if (has_additional_src_work || has_additional_dst_work) { - global_wg_size[packed_dim]++; // Increase the global work group size in - // packed dimension - final_range[packed_dim]++; // Increase the range in packed dimension - } - - auto shader = VK_KERNEL_FROM_STR(kernel_name); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), - // Inputs and Outputs - { - {out, vkapi::kWrite}, - {out, vkapi::kRead}, - {in, vkapi::kRead}, - }, - // Parameter buffers - {}, - // Push Constants - { - PushConstantDataInfo( - &final_range, sizeof(final_range), sizeof(ivec4)), - PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)), - PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)), - }, - // Specialization Constants - {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - -void add_copy_channel_offset_node( - ComputeGraph& graph, - const ValueRef in, - int32_t channel_range, - int32_t src_channel_offset, - int32_t dst_channel_offset, - const ValueRef out) { - // Likely need to prepad these numbers. - const std::vector in_sizes = graph.sizes_of(in); - const std::vector out_sizes = graph.sizes_of(out); - - VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); - VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); - - // NOTE: This function should be able to support 1d and 2d tensors when - // range=1, src_offset=dst_offset=1. - VK_CHECK_COND(graph.dim_of(in) >= 3, "Src dim should be at least 3"); - VK_CHECK_COND(graph.dim_of(out) >= 3, "Dst dim should be at least 3"); - - VK_CHECK_COND( - dim_at(in_sizes) >= src_channel_offset + channel_range, - "Src channel (", - src_channel_offset, - ") and range (", - channel_range, - ") should be less than or equal to input tensor's channel size (", - dim_at(in_sizes), - ")"); - - VK_CHECK_COND( - dim_at(out_sizes) >= dst_channel_offset + channel_range, - "Dst channel (", - dst_channel_offset, - ") and range (", - channel_range, - ") should be less than or equal to input tensor's channel size (", - dim_at(out_sizes), - ")"); - - VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative"); - VK_CHECK_COND( - src_channel_offset >= 0, "Src channel offset must be non-negative"); - VK_CHECK_COND( - dst_channel_offset >= 0, "Dst channel offset must be non-negative"); - - std::string kernel_name = "copy_channel_offset"; - kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - int32_t out_channels = dim_at(out_sizes); - - // Copy one batch at a time. - for (int batch_idx = 0; batch_idx < dim_at(in_sizes); batch_idx++) { - // Mapping the tensor NCHW coordinates into texture XYZ coordinates - int32_t dst_first_z = dst_channel_offset / 4; - int32_t dst_last_z = (dst_channel_offset + channel_range - 1) / 4; - - // We copy the entire width and height dimension. For the channel dimension, - // we use the z-dimension of the global_size to specify the texture range. - // The shader combines the global invocation id and the dst_offset to get - // the actual coordinate. - - const ivec3 dst_offset{ - 0, 0, dst_first_z + batch_idx * utils::div_up_4(out_channels)}; - - const uvec3 global_size{ - utils::safe_downcast(dim_at(in_sizes)), - utils::safe_downcast(dim_at(in_sizes)), - utils::safe_downcast(dst_last_z - dst_first_z + 1)}; - const uvec3 local_size = graph.create_local_wg_size(global_size); - - const utils::ivec4 range_params = { - static_cast(global_size[0]), - static_cast(global_size[1]), - static_cast(global_size[2]), - channel_range}; - - const ivec4 offset_params = { - dst_offset[0], dst_offset[1], dst_offset[2], dst_channel_offset}; - - auto shader = VK_KERNEL_FROM_STR(kernel_name); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, - // Inputs and Outputs - { - {out, vkapi::kWrite}, - {out, vkapi::kRead}, - {in, vkapi::kRead}, - }, - // Parameter buffers - {}, - // Push Constants - {graph.sizes_pc_of(out), - graph.sizes_pc_of(in), - PushConstantDataInfo(&range_params, sizeof(range_params)), - PushConstantDataInfo(&offset_params, sizeof(offset_params)), - PushConstantDataInfo(&src_channel_offset, sizeof(src_channel_offset))}, - // Specialization Constants - {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}, - // Resize Args - {}, - // Resizing Logic - nullptr)); - } -} - -void add_copy_offset_node( - ComputeGraph& graph, - ValueRef in, - ValueRef range_ref, - ValueRef src_offset_ref, - ValueRef dst_offset_ref, - ValueRef out) { - ivec3 range = utils::make_ivec3(*graph.get_int_list(range_ref)); - ivec3 src = utils::make_ivec3(*graph.get_int_list(src_offset_ref)); - ivec3 dst = utils::make_ivec3(*graph.get_int_list(dst_offset_ref)); - - ivec4 src_offset = {src[0], src[1], src[2], 0}; - ivec4 dst_offset = {dst[0], dst[1], dst[2], 0}; - - add_copy_offset_node( - graph, in, range, src_offset, dst_offset, out, false, false); -} - -void copy_offset(ComputeGraph& graph, const std::vector& args) { - add_copy_offset_node(graph, args[0], args[1], args[2], args[3], args[4]); -} - -void copy_channel_offset( - ComputeGraph& graph, - const std::vector& args) { - ValueRef in = args[0]; - ValueRef channel_range_ref = args[1]; - ValueRef src_channel_offset_ref = args[2]; - ValueRef dst_channel_offset_ref = args[3]; - ValueRef out = args[4]; - - auto channel_range = graph.extract_scalar(channel_range_ref); - auto src_channel_offset = - graph.extract_scalar(src_channel_offset_ref); - auto dst_channel_offset = - graph.extract_scalar(dst_channel_offset_ref); - - add_copy_channel_offset_node( - graph, in, channel_range, src_channel_offset, dst_channel_offset, out); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(etvk.copy_offset, copy_offset); - VK_REGISTER_OP(etvk.copy_channel_offset, copy_channel_offset); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.h b/backends/vulkan/runtime/graph/ops/impl/Copy.h deleted file mode 100644 index 41956d482d9..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include - -#include - -namespace vkcompute { - -// add_copy_offset_node resumes the vkCmdCopyImage command. It copies the -// texture extents specified by the range, src_offset, and dst_offset (all are -// in texture coordinate (x, y, z) from the input image to the output image. -// src_offset.w and dst_offset.w may contain channel size information. -// -// It is possible to have input and output to point to the same image -// object. But when the source range and destination range overlap, the behavior -// is undefined. -// -// boolean flags calc_out_pos_using_src_chnl and calc_in_pos_using_dst_chnl -// can be used to specify an indexing function in the shader -// If calc_out_pos_using_src_chnl is set to true channel and batch index will be -// calculated based on source channel size and will be used to determine -// destination texel position. -// -// If calc_in_pos_using_dst_chnl is set to truechannel and batch index will be -// calculated based on destination channel size and will be used to determine -// source texel position. -// -// If both are true calc_out_pos_using_src_chnl is picked. If both are false no -// index calculation happens. -void add_copy_offset_node( - ComputeGraph& graph, - const ValueRef in, - const utils::ivec3& range, - const utils::ivec4& src_offset, - const utils::ivec4& dst_offset, - const ValueRef out, - bool calc_out_pos_using_src_chnl, - bool calc_in_pos_using_dst_chnl); - -// add_copy_packed_dim_offset_node behaves similar to add_copy_node, except that -// its used when copying packed dimension, if tensor is width or height packed. -// src_offset.w and dst_offset.w may contain channel size information. -// -// It copies the texture extents specified by the range, src_offset, and -// dst_offset (all are in texture coordinate (x, y, z) from the input image to -// the output image. -void add_copy_packed_dim_offset_node( - ComputeGraph& graph, - const ValueRef in, - const utils::ivec3& range, - const utils::ivec4& src_offset, - const utils::ivec4& dst_offset, - const ValueRef out); - -// add_copy_channel_offset_node behaves similar to add_copy_node, except that it -// works on the channel dimensions of the tensor (up to 4 dimensions in NCHW). -// The range and offset arguments are in the tensor coordinate. It assumes the -// underlying texture is channel-packed. -// -// This function is specialized implementation for copying -// channel packed values. The complication comes from when reading / writing the -// channel dimension on indices that are not aligned to packing, we will need -// be careful about the boundaries. -// -// It achieves the following: -// out[:, dst_channel_offset:dst_channel_offset + channel_range, :, :] = -// in [:, src_channel_offset:src_channel_offset + channel_range, :, :] -void add_copy_channel_offset_node( - ComputeGraph& graph, - const ValueRef in, - int32_t channel_range, - int32_t src_channel_offset, - int32_t dst_channel_offset, - const ValueRef out); - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp deleted file mode 100644 index a217734653d..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ /dev/null @@ -1,843 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#include -#include -#include -#include - -namespace vkcompute { - -void resize_dequantize_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - - const std::vector in_sizes = graph->sizes_of(in); - graph->virtual_resize(out, in_sizes); -} - -utils::uvec3 dequantize_per_channel_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)args; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. We need to ensure that we dispatch the correct - // number of workgroups in the Z dimension to cover all batch-channel - // combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -utils::uvec3 dequantize_block_wise_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. We need to ensure that we dispatch the correct - // number of workgroups in the Z dimension to cover all batch-channel - // combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -void add_dequantize_per_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_per_token_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_token"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - int num_tokens = static_cast(graph.sizes_of(scale)[0]); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_per_channel_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& axis, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_channel"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - int axis_val = static_cast(graph.get_int(axis)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - // Normalize axis and convert from NCHW to WHCN using utility functions - const auto input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - // Normalize axis to handle negative indices - axis_val = normalize(axis_val, ndim); - - // Convert from NCHW axis to WHCN axis for shader (vulkan representation) - int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); - - int num_channels; - if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { - // For batch dimension dequantization in 4D tensors, pass the actual number - // of channels so the shader can correctly unfold the batch-channel folding - num_channels = static_cast(input_sizes[1]); // Channel dimension - } else { - num_channels = static_cast(input_sizes[axis_val]); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - dequantize_per_channel_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_block_wise_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& block_size, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // Convert dimensions to WHCN order for shader - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) - utils::ivec4 num_blocks_vec = { - tensor_size_whcn[0] / block_size_vec[0], - tensor_size_whcn[1] / block_size_vec[1], - tensor_size_whcn[2] / block_size_vec[2], - tensor_size_whcn[3] / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - dequantize_block_wise_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void dequantize_per_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - add_dequantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void dequantize_per_token_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_tokens - // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors - // (size [num_tokens, 1]) - VK_CHECK_COND(scale_numel == num_tokens); - VK_CHECK_COND(zero_point_numel == num_tokens); - - add_dequantize_per_token_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void dequantize_per_channel_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef axis = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Normalize axis - int axis_val = static_cast(graph.get_int(axis)); - const auto input_sizes = graph.sizes_of(input); - int ndim = graph.dim_of(input); - if (axis_val < 0) { - axis_val += ndim; - } - - // Verify axis is valid - VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); - - // Get number of channels along the specified axis - int64_t num_channels = input_sizes[axis_val]; - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_channels - VK_CHECK_COND(scale_numel == num_channels); - VK_CHECK_COND(zero_point_numel == num_channels); - - add_dequantize_per_channel_node( - graph, input, scale, zero_point, axis, quant_min, quant_max, output); -} - -void dequantize_affine_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef block_size = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef input_dtype = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - (void)input_dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Verify block_size is valid (each dimension must divide evenly into input - // size) - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - for (size_t i = 0; i < input_sizes.size(); i++) { - if ((*block_size_list)[i] > 1) { - VK_CHECK_COND( - input_sizes[i] % (*block_size_list)[i] == 0, - "Input size at dimension ", - i, - " (", - input_sizes[i], - ") must be divisible by block_size at dimension ", - i, - " (", - (*block_size_list)[i], - ")"); - } - } - - add_dequantize_block_wise_node( - graph, - input, - block_size, - scale, - zero_point, - quant_min, - quant_max, - output); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_tensor.tensor, - dequantize_per_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_token.default, - dequantize_per_token_impl); - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_channel.default, - dequantize_per_channel_impl); - - // TorchAO affine dequantization operators - VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Gather.cpp b/backends/vulkan/runtime/graph/ops/impl/Gather.cpp new file mode 100644 index 00000000000..584a8d0437b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Gather.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +using utils::GPUMemoryLayout; +using utils::StorageType; + +void resize_gather_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef index = args.at(1).refs.at(1); + + // Output shape is the same as index shape + std::vector out_sizes = graph->sizes_of(index); + graph->virtual_resize(out, out_sizes); +} + +void add_gather_node( + ComputeGraph& graph, + const ValueRef input, + const int64_t dim, + const ValueRef index, + const ValueRef out) { + std::string kernel_name = "gather"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), graph.meta_ubo(input), graph.meta_ubo(index)}; + + const int64_t dim_whcn = graph.dim_of(input) - dim - 1; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{input, index}, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {static_cast(dim_whcn)}, + // Resize Args + {}, + // Resizing Logic + resize_gather_node)); +} + +void gather(ComputeGraph& graph, const std::vector& args) { + ValueRef input = args[0]; + ValueRef dim_ref = args[1]; + ValueRef index = args[2]; + ValueRef out = args[4]; + + int64_t dim = graph.extract_scalar(dim_ref); + + add_gather_node(graph, input, dim, index, out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.gather.default, gather); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp deleted file mode 100644 index 88f77261f4f..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ /dev/null @@ -1,836 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#include -#include - -#include - -namespace vkcompute { - -void resize_quantize_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - - const std::vector in_sizes = graph->sizes_of(in); - graph->virtual_resize(out, in_sizes); -} - -utils::uvec3 quantize_per_channel_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)args; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel quantization along the batch axis, - // we need to ensure that we dispatch the correct number of workgroups in the - // Z dimension to cover all batch-channel combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -utils::uvec3 quantize_block_wise_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel quantization along the batch axis, - // we need to ensure that we dispatch the correct number of workgroups in the - // Z dimension to cover all batch-channel combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -void add_quantize_per_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_per_token_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_token"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - int num_tokens = static_cast(graph.sizes_of(scale)[0]); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_per_channel_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& axis, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_channel"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - int axis_val = static_cast(graph.get_int(axis)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - // Normalize axis and convert from NCHW to WHCN using utility functions - const auto input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - // Normalize axis to handle negative indices - axis_val = normalize(axis_val, ndim); - - // Convert from NCHW axis to WHCN axis for shader (vulkan representation) - int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); - - int num_channels; - if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { - // For batch dimension quantization in 4D tensors, pass the actual number of - // channels so the shader can correctly unfold the batch-channel folding - num_channels = static_cast(input_sizes[1]); // Channel dimension - } else { - num_channels = static_cast(input_sizes[axis_val]); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - quantize_per_channel_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_block_wise_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& block_size, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // Convert PyTorch dimensions to WHCN order for shader - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) - utils::ivec4 num_blocks_vec = { - tensor_size_whcn[0] / block_size_vec[0], - tensor_size_whcn[1] / block_size_vec[1], - tensor_size_whcn[2] / block_size_vec[2], - tensor_size_whcn[3] / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - quantize_block_wise_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void quantize_per_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - add_quantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void quantize_per_token_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_tokens - // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors - // (size [num_tokens, 1]) - VK_CHECK_COND(scale_numel == num_tokens); - VK_CHECK_COND(zero_point_numel == num_tokens); - - add_quantize_per_token_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void quantize_per_channel_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef axis = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Normalize axis - int axis_val = static_cast(graph.get_int(axis)); - const auto input_sizes = graph.sizes_of(input); - int64_t ndim = graph.dim_of(input); - if (axis_val < 0) { - axis_val += ndim; - } - - // Verify axis is valid - VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); - - // Get number of channels along the specified axis - int64_t num_channels = input_sizes[axis_val]; - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_channels - VK_CHECK_COND(scale_numel == num_channels); - VK_CHECK_COND(zero_point_numel == num_channels); - - add_quantize_per_channel_node( - graph, input, scale, zero_point, axis, quant_min, quant_max, output); -} - -void quantize_affine_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef block_size = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Verify block_size is valid (each dimension must divide evenly into input - // size) - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - for (size_t i = 0; i < input_sizes.size(); i++) { - if ((*block_size_list)[i] > 1) { - VK_CHECK_COND( - input_sizes[i] % (*block_size_list)[i] == 0, - "Input size at dimension ", - i, - " (", - input_sizes[i], - ") must be divisible by block_size at dimension ", - i, - " (", - (*block_size_list)[i], - ")"); - } - } - - add_quantize_block_wise_node( - graph, - input, - block_size, - scale, - zero_point, - quant_min, - quant_max, - output); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.quantize_per_tensor.tensor, - quantize_per_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); - VK_REGISTER_OP( - quantized_decomposed.quantize_per_channel.default, - quantize_per_channel_impl); - - // TorchAO affine quantization operators - VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp new file mode 100644 index 00000000000..ee8f8a1afb4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp @@ -0,0 +1,450 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace vkcompute { + +// +// General utilities +// + +bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input) { + return graph->size_at(-2, fp_input) == 1; +} + +// +// Dispatch utilities (Linear) +// + +std::tuple get_quantized_input_num_blocks( + ComputeGraph& graph, + const ValueRef input) { + std::vector input_sizes = graph.sizes_of(input); + const int64_t ndim = graph.dim_of(input); + + const int64_t M = input_sizes.at(ndim - 2); + const int64_t K = input_sizes.at(ndim - 1); + + const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); + const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); + + return std::make_tuple(num_blocks_M, num_blocks_K); +} + +utils::uvec3 quantize_and_pack_4h4w_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input = args.at(1).refs.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, input); + + return { + utils::safe_downcast(num_blocks_K), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +vkapi::ShaderInfo pick_quantize_and_pack_4h4w_with_group_sums_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int_input = args.at(0).refs.at(0); + const ValueRef fp_input = args.at(1).refs.at(0); + const ValueRef group_size = resize_args.at(0); + + const int64_t group_size_val = graph->extract_scalar(group_size); + + std::string shader_name = "quantize_and_pack_4h4w_with_group_sums"; + if (group_size_val >= 128) { + shader_name += "_o2w32"; + } else { + shader_name += "_o4w16"; + } + + add_storage_type_suffix( + shader_name, graph->storage_type_of(packed_int_input)); + add_storage_type_suffix(shader_name, graph->storage_type_of(fp_input)); + add_dtype_suffix(shader_name, graph->dtype_of(fp_input)); + + return VK_KERNEL_FROM_STR(shader_name); +} + +utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef fp_input = args.at(1).refs.at(0); + // For gemv cases, skip the quantize and pack input step in favor of computing + // the quantized linear as a weight only quantized linear operation. The + // rationale for this is that gemv is a memory bound operation and may not + // necessarily benefit from quantizing the input and computing with integer + // accumulation. + if (is_gemv(graph, fp_input)) { + return {0u, 0u, 0u}; + } + + const ValueRef group_size = resize_args.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, fp_input); + + const int64_t group_size_val = graph->extract_scalar(group_size); + const int64_t blocks_per_group = group_size_val / 4; + + const int64_t num_groups = num_blocks_K / blocks_per_group; + + return { + utils::safe_downcast(num_groups), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef fp_input = args.at(1).refs.at(0); + // For gemv, skip the quantize input step since the quantized linear is + // computed as a weight only quantized linear operation. + if (is_gemv(graph, fp_input)) { + return {1u, 1u, 1u}; + } + + uint32_t groups_per_wg = 2u; + uint32_t workers_per_group = 32u; + + if (shader.kernel_name.find("o4w16") != std::string::npos) { + groups_per_wg = 4u; + workers_per_group = 16u; + } + + return {groups_per_wg, 1u, workers_per_group}; +} + +// +// Dispatch logic (Linear) +// + +void add_quantize_and_pack_4h4w_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef packed_int_input, + const ValueRef group_size) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale_data); + int32_t zp = graph.extract_scalar(input_zp_data); + + std::string shader_name = "quantize_and_pack_4h4w_per_tensor"; + add_storage_type_suffix(shader_name, graph.storage_type_of(packed_int_input)); + add_storage_type_suffix(shader_name, graph.storage_type_of(fp_input)); + add_dtype_suffix(shader_name, graph.dtype_of(fp_input)); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(shader_name), + quantize_and_pack_4h4w_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{packed_int_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {})); +} + +void add_quantize_and_pack_4h4w_with_group_sums_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef int_input_sums, + const ValueRef packed_input_scales, + const ValueRef packed_input_zps, + const ValueRef packed_int_input, + const ValueRef group_size) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(input_quant_config.granularity == kPerChannel); + + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + const int32_t group_size_val = graph.extract_scalar(group_size); + const int32_t blocks_per_group = utils::div_up(group_size_val, int32_t(4)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_quantize_and_pack_4h4w_with_group_sums_shader, + pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size, + pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size, + // Inputs and Outputs + {{{packed_int_input, int_input_sums}, vkapi::kWrite}, + {{fp_input, packed_input_scales, packed_input_zps}, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {blocks_per_group}, + // Resize args + {group_size})); +} + +// +// Dispatch utilities (Conv2d) +// + +utils::uvec3 pick_quantize_and_pack_4w4c_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef fp_input = args.at(1).refs.at(0); + + const uint32_t W = graph->size_at(-1, fp_input); + const uint32_t H = graph->size_at(-2, fp_input); + const uint32_t C = graph->size_at(-3, fp_input); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {W4, H, C4}; +} + +utils::uvec3 pick_unpack_4w4c_and_dequantize_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef fp_output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, fp_output); + const uint32_t H = graph->size_at(-2, fp_output); + const uint32_t C = graph->size_at(-3, fp_output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {W4, H, C4}; +} + +// +// Dispatch logic (Conv2d) +// + +void add_quantize_and_pack_4w4c_node( + ComputeGraph& graph, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_int8_input) { + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + // Get shader for quantized conv2d linear tiled + std::string kernel_name = "quantize_and_pack_4w4c_per_tensor"; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_input)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_input)); + add_dtype_suffix(kernel_name, graph.dtype_of(fp_input)); + + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_quantize_and_pack_4w4c_global_wg_size, + pick_wc_square_wg_size, + // Inputs and Outputs + {{packed_int8_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +void add_unpack_4w4c_and_dequantize_node( + ComputeGraph& graph, + const ValueRef packed_int8_output, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef fp_output) { + float scale = graph.extract_scalar(output_scale); + int32_t zp = graph.extract_scalar(output_zp); + + // Get shader for quantized conv2d linear tiled + std::string kernel_name = "unpack_4w4c_and_dequantize_per_tensor"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_output)); + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + add_dtype_suffix(kernel_name, graph.dtype_of(fp_output)); + + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_output)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_unpack_4w4c_and_dequantize_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{fp_output, vkapi::kWrite}, {packed_int8_output, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +// +// Operator Entrypoints +// + +void quantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int32_t arg_idx = 0; + const ValueRef fp_input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + (void)quant_min; + const ValueRef quant_max = args[arg_idx++]; + (void)quant_max; + const ValueRef dtype = args[arg_idx++]; + (void)dtype; + + const ValueRef int8_output = args[arg_idx++]; + + VK_CHECK_COND( + graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C); + + add_quantize_and_pack_4w4c_node( + graph, fp_input, scale, zero_point, int8_output); +} + +void dequantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int32_t arg_idx = 0; + const ValueRef int8_input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + (void)quant_min; + const ValueRef quant_max = args[arg_idx++]; + (void)quant_max; + const ValueRef dtype = args[arg_idx++]; + (void)dtype; + const ValueRef output_dtype = args[arg_idx++]; + (void)output_dtype; + + const ValueRef fp_output = args[arg_idx++]; + + VK_CHECK_COND( + graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C); + + add_unpack_4w4c_and_dequantize_node( + graph, int8_input, scale, zero_point, fp_output); +} + +void qdq8ta_conv2d_input( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + add_quantize_and_pack_4w4c_node( + graph, fp_input, scale, zero_point, packed_int8_input); + + add_unpack_4w4c_and_dequantize_node( + graph, packed_int8_input, scale, zero_point, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP( + quantized_decomposed.quantize_per_tensor.default, + quantize_per_tensor_impl); + VK_REGISTER_OP( + quantized_decomposed.dequantize_per_tensor.default, + dequantize_per_tensor_impl); + VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.h b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.h new file mode 100644 index 00000000000..96e9cc7c1d3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace vkcompute { + +// +// General utils +// + +bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input); + +// +// Quantize, Dequantize for Linear/Matmul +// + +void add_quantize_and_pack_4h4w_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef packed_int_input, + const ValueRef group_size); + +void add_quantize_and_pack_4h4w_with_group_sums_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef int_input_sums, + const ValueRef packed_input_scales, + const ValueRef packed_input_zps, + const ValueRef packed_int_input, + const ValueRef group_size); + +// +// Quantize, Dequantize for Convolution +// + +void add_quantize_and_pack_4w4c_node( + ComputeGraph& graph, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_int8_input); + +void add_unpack_4w4c_and_dequantize_node( + ComputeGraph& graph, + const ValueRef packed_int8_output, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef fp_output); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp index 4b359f12700..99b5880c2eb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include @@ -178,10 +178,10 @@ void add_q8ta_q8ta_q8to_test( utils::kBuffer, utils::kPackedInt8_4W4C); - add_quantize_and_pack_q8ta_conv2d_input_node( + add_quantize_and_pack_4w4c_node( graph, fp_input_a, input_a_scale, input_a_zp, packed_int8_input_a); - add_quantize_and_pack_q8ta_conv2d_input_node( + add_quantize_and_pack_4w4c_node( graph, fp_input_b, input_b_scale, input_b_zp, packed_int8_input_b); std::vector add_args = { @@ -198,7 +198,7 @@ void add_q8ta_q8ta_q8to_test( add_q8ta_q8ta_q8to(graph, add_args); - add_unpack_and_dequantize_q8ta_conv2d_output_node( + add_unpack_4w4c_and_dequantize_node( graph, packed_int8_output, output_scale, output_zp, fp_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp index 775e4534cfb..a0f5763df06 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -274,40 +275,6 @@ std::vector calculate_output_im2col_sizes( // Shader dispatch utilities // -utils::uvec3 pick_quantize_and_pack_conv2d_input_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef fp_input = args.at(1).refs.at(0); - - const uint32_t W = graph->size_at(-1, fp_input); - const uint32_t H = graph->size_at(-2, fp_input); - const uint32_t C = graph->size_at(-3, fp_input); - - const uint32_t W4 = utils::div_up_4(W); - const uint32_t C4 = utils::div_up_4(C); - - return {W4, H, C4}; -} - -utils::uvec3 pick_unpack_and_dequantize_conv2d_output_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef fp_output = args.at(0).refs.at(0); - - const uint32_t W = graph->size_at(-1, fp_output); - const uint32_t H = graph->size_at(-2, fp_output); - const uint32_t C = graph->size_at(-3, fp_output); - - const uint32_t W4 = utils::div_up_4(W); - const uint32_t C4 = utils::div_up_4(C); - - return {W4, H, C4}; -} - utils::uvec3 im2col_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -704,94 +671,6 @@ void add_input_im2col_packed_int8_node( nullptr)); } -void add_quantize_and_pack_q8ta_conv2d_input_node( - ComputeGraph& graph, - const ValueRef fp_input, - const ValueRef input_scale, - const ValueRef input_zp, - const ValueRef packed_int8_input) { - float inv_scale = 1.0f / graph.extract_scalar(input_scale); - int32_t zp = graph.extract_scalar(input_zp); - - // Get shader for quantized conv2d linear tiled - std::string kernel_name = "quantize_and_pack_q8ta_conv2d_input"; - add_storage_type_suffix( - kernel_name, graph.storage_type_of(packed_int8_input)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_input)); - add_dtype_suffix(kernel_name, graph.dtype_of(fp_input)); - - vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); - - vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; - - std::vector push_constants = { - PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), - PushConstantDataInfo(&zp, sizeof(zp)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - pick_quantize_and_pack_conv2d_input_global_wg_size, - pick_wc_square_wg_size, - // Inputs and Outputs - {{packed_int8_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, - // Shader params buffers - param_buffers, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize args - {}, - // Resizing Logic - nullptr)); -} - -void add_unpack_and_dequantize_q8ta_conv2d_output_node( - ComputeGraph& graph, - const ValueRef packed_int8_output, - const ValueRef output_scale, - const ValueRef output_zp, - const ValueRef fp_output) { - float scale = graph.extract_scalar(output_scale); - int32_t zp = graph.extract_scalar(output_zp); - - // Get shader for quantized conv2d linear tiled - std::string kernel_name = "unpack_and_dequantize_q8ta_conv2d_output"; - add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_output)); - add_storage_type_suffix( - kernel_name, graph.storage_type_of(packed_int8_output)); - add_dtype_suffix(kernel_name, graph.dtype_of(fp_output)); - - vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); - - vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_output)}; - - std::vector push_constants = { - PushConstantDataInfo(&scale, sizeof(scale)), - PushConstantDataInfo(&zp, sizeof(zp)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - pick_unpack_and_dequantize_conv2d_output_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{fp_output, vkapi::kWrite}, {packed_int8_output, vkapi::kRead}}, - // Shader params buffers - param_buffers, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize args - {}, - // Resizing Logic - nullptr)); -} - void add_quantize_and_pack_im2col_node( ComputeGraph& graph, const ValueRef input_image, @@ -1607,59 +1486,6 @@ void conv2d_q8ta_q8csw_q8to( packed_int8_output); } -// -// Quantize and dequantize operators -// - -void quantize_q8ta_for_conv2d( - ComputeGraph& graph, - const std::vector& args) { - int32_t idx = 0; - const ValueRef fp_input = args.at(idx++); - const ValueRef scale = args.at(idx++); - const ValueRef zero_point = args.at(idx++); - const ValueRef packed_int8_input = args.at(idx++); - - add_quantize_and_pack_q8ta_conv2d_input_node( - graph, fp_input, scale, zero_point, packed_int8_input); -} - -void dequantize_q8to_from_conv2d( - ComputeGraph& graph, - const std::vector& args) { - int32_t idx = 0; - const ValueRef packed_int8_output = args.at(idx++); - const ValueRef scale = args.at(idx++); - const ValueRef zero_point = args.at(idx++); - const ValueRef fp_output = args.at(idx++); - - add_unpack_and_dequantize_q8ta_conv2d_output_node( - graph, packed_int8_output, scale, zero_point, fp_output); -} - -void qdq8ta_conv2d_input( - ComputeGraph& graph, - const std::vector& args) { - int32_t idx = 0; - const ValueRef fp_input = args.at(idx++); - const ValueRef scale = args.at(idx++); - const ValueRef zero_point = args.at(idx++); - const ValueRef fp_output = args.at(idx++); - - TmpTensor packed_int8_input( - &graph, - graph.sizes_of(fp_input), - vkapi::kInt8x4, - utils::kBuffer, - utils::kPackedInt8_4W4C); - - add_quantize_and_pack_q8ta_conv2d_input_node( - graph, fp_input, scale, zero_point, packed_int8_input); - - add_unpack_and_dequantize_q8ta_conv2d_output_node( - graph, packed_int8_input, scale, zero_point, fp_output); -} - // // Test operators // @@ -1698,7 +1524,7 @@ void conv2d_q8ta_q8csw_q8to_test( utils::kBuffer, utils::kPackedInt8_4W4C); - add_quantize_and_pack_q8ta_conv2d_input_node( + add_quantize_and_pack_4w4c_node( graph, fp_input, input_scale, input_zp, packed_int8_input); std::vector conv2d_args = { @@ -1720,19 +1546,14 @@ void conv2d_q8ta_q8csw_q8to_test( conv2d_q8ta_q8csw_q8to(graph, conv2d_args); - add_unpack_and_dequantize_q8ta_conv2d_output_node( + add_unpack_4w4c_and_dequantize_node( graph, packed_int8_output, output_scale, output_zp, fp_output); } REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw); VK_REGISTER_OP(et_vk.conv2d_q8csw.default, conv2d_q8csw); - VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input); VK_REGISTER_OP(etvk.conv2d_q8ta_q8csw_q8to.test, conv2d_q8ta_q8csw_q8to_test); - VK_REGISTER_OP( - et_vk.quantize_q8ta_for_conv2d.default, quantize_q8ta_for_conv2d); - VK_REGISTER_OP( - et_vk.dequantize_q8to_from_conv2d.default, dequantize_q8to_from_conv2d); VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw_q8to.default, conv2d_q8ta_q8csw_q8to); VK_REGISTER_OP( et_vk.conv2d_q8ta_q8csw_q8to_dw.default, conv2d_q8ta_q8csw_q8to); diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h index 33474cee47b..c3ea15bc318 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h @@ -12,31 +12,7 @@ namespace vkcompute { -// -// Quantize and dequantize functions for conv2d that can be reused by other -// operations -// - -/** - * Add a dispatch node to quantize a floating-point input tensor to a packed - * int8 tensor for use in quantized operations. - */ -void add_quantize_and_pack_q8ta_conv2d_input_node( - ComputeGraph& graph, - const ValueRef fp_input, - const ValueRef input_scale, - const ValueRef input_zp, - const ValueRef packed_int8_input); - -/** - * Add a dispatch node to unpack and dequantize a packed int8 output tensor back - * to a floating-point tensor. - */ -void add_unpack_and_dequantize_q8ta_conv2d_output_node( - ComputeGraph& graph, - const ValueRef packed_int8_output, - const ValueRef output_scale, - const ValueRef output_zp, - const ValueRef fp_output); +// This header is intentionally empty as all quantize/dequantize functions +// have been moved to QuantizeDequantize.h } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 97566038501..7a42d463f2a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -19,10 +20,6 @@ namespace vkcompute { // Shader dispatch utilities // -bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input) { - return graph->size_at(-2, fp_input) == 1; -} - void resize_linear_qw_node( ComputeGraph* graph, const std::vector& args, @@ -105,120 +102,6 @@ utils::uvec3 quantized_linear_local_wg_size( } } -std::tuple get_quantized_input_num_blocks( - ComputeGraph& graph, - const ValueRef input) { - std::vector input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - const int64_t M = input_sizes.at(ndim - 2); - const int64_t K = input_sizes.at(ndim - 1); - - const int64_t num_blocks_M = utils::div_up(M, int64_t(4)); - const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); - - return std::make_tuple(num_blocks_M, num_blocks_K); -} - -utils::uvec3 quant_pack_input_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef input = args.at(1).refs.at(0); - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(*graph, input); - - return { - utils::safe_downcast(num_blocks_K), - utils::safe_downcast(num_blocks_M), - 1u}; -} - -vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef packed_int_input = args.at(0).refs.at(0); - const ValueRef fp_input = args.at(1).refs.at(0); - const ValueRef group_size = resize_args.at(0); - - const int64_t group_size_val = graph->extract_scalar(group_size); - - std::string shader_name = "quantize_and_pack_linear_input_with_sums"; - if (group_size_val >= 128) { - shader_name += "_o2w32"; - } else { - shader_name += "_o4w16"; - } - - add_storage_type_suffix( - shader_name, graph->storage_type_of(packed_int_input)); - add_storage_type_suffix(shader_name, graph->storage_type_of(fp_input)); - add_dtype_suffix(shader_name, graph->dtype_of(fp_input)); - - return VK_KERNEL_FROM_STR(shader_name); -} - -utils::uvec3 pick_quantize_and_pack_input_with_sums_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - const ValueRef fp_input = args.at(1).refs.at(0); - // For gemv cases, skip the quantize and pack input step in favor of computing - // the quantized linear as a weight only quantized linear operation. The - // rationale for this is that gemv is a memory bound operation and may not - // necessarily benefit from quantizing the input and computing with integer - // accumulation. - if (is_gemv(graph, fp_input)) { - return {0u, 0u, 0u}; - } - - const ValueRef group_size = resize_args.at(0); - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(*graph, fp_input); - - const int64_t group_size_val = graph->extract_scalar(group_size); - const int64_t blocks_per_group = group_size_val / 4; - - const int64_t num_groups = num_blocks_K / blocks_per_group; - - return { - utils::safe_downcast(num_groups), - utils::safe_downcast(num_blocks_M), - 1u}; -} - -utils::uvec3 pick_quantize_and_pack_input_with_sums_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef fp_input = args.at(1).refs.at(0); - // For gemv, skip the quantize input step since the quantized linear is - // computed as a weight only quantized linear operation. - if (is_gemv(graph, fp_input)) { - return {1u, 1u, 1u}; - } - - uint32_t groups_per_wg = 2u; - uint32_t workers_per_group = 32u; - - if (shader.kernel_name.find("o4w16") != std::string::npos) { - groups_per_wg = 4u; - workers_per_group = 16u; - } - - return {groups_per_wg, 1u, workers_per_group}; -} - vkapi::ShaderInfo pick_linear_qw_shader( ComputeGraph* graph, const std::vector& args, @@ -421,7 +304,7 @@ ValueRef prepack_quantized_linear_weight( /* * Shader dispatch for linear with quantized weight but fp activations. */ -DynamicDispatchNode make_linear_qw_node( +void add_linear_qw_node( ComputeGraph& graph, const QuantizationConfig& weight_quant_config, const ValueRef fp_input, @@ -458,7 +341,7 @@ DynamicDispatchNode make_linear_qw_node( const ValueRef is_4bit_flag = weight_quant_config.nbits == 4 ? group_size : kDummyValueRef; - return DynamicDispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, pick_linear_qw_shader, quantized_linear_global_wg_size, @@ -476,98 +359,10 @@ DynamicDispatchNode make_linear_qw_node( // Resize args {is_4bit_flag, weight_data}, // Resizing Logic - resize_linear_qw_node); -} - -DynamicDispatchNode make_quantize_and_pack_linear_input_node( - ComputeGraph& graph, - const QuantizationConfig& input_quant_config, - const ValueRef fp_input, - const ValueRef packed_input_scale, - const ValueRef packed_input_zp, - const ValueRef input_scale_data, - const ValueRef input_zp_data, - const ValueRef packed_int_input, - const ValueRef group_size) { - // Only certain quantization types supported at the moment - VK_CHECK_COND(input_quant_config.granularity == kPerTensor); - - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(graph, fp_input); - - float inv_scale = 1.0f / graph.extract_scalar(input_scale_data); - int32_t zp = graph.extract_scalar(input_zp_data); - - std::string shader_name = "quantize_and_pack_linear_input_per_tensor"; - add_storage_type_suffix(shader_name, graph.storage_type_of(packed_int_input)); - add_storage_type_suffix(shader_name, graph.storage_type_of(fp_input)); - add_dtype_suffix(shader_name, graph.dtype_of(fp_input)); - - vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; - - std::vector push_constants = { - PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), - PushConstantDataInfo(&zp, sizeof(zp)), - }; - - return DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(shader_name), - quant_pack_input_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{packed_int_input, vkapi::kWrite}, {fp_input, vkapi::kRead}}, - // Shader params buffers - param_buffers, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize args - {}); + resize_linear_qw_node)); } -DynamicDispatchNode make_quantize_and_pack_linear_input_with_sums_node( - ComputeGraph& graph, - const QuantizationConfig& input_quant_config, - const ValueRef fp_input, - const ValueRef int_input_sums, - const ValueRef packed_input_scales, - const ValueRef packed_input_zps, - const ValueRef packed_int_input, - const ValueRef group_size) { - // Only certain quantization types supported at the moment - VK_CHECK_COND(input_quant_config.granularity == kPerChannel); - - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(graph, fp_input); - - vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; - - const int32_t group_size_val = graph.extract_scalar(group_size); - const int32_t blocks_per_group = utils::div_up(group_size_val, int32_t(4)); - - return DynamicDispatchNode( - graph, - pick_quantize_and_pack_input_with_sums_shader, - pick_quantize_and_pack_input_with_sums_global_wg_size, - pick_quantize_and_pack_input_with_sums_local_wg_size, - // Inputs and Outputs - {{{packed_int_input, int_input_sums}, vkapi::kWrite}, - {{fp_input, packed_input_scales, packed_input_zps}, vkapi::kRead}}, - // Shader params buffers - param_buffers, - // Push Constants - {}, - // Specialization Constants - {blocks_per_group}, - // Resize args - {group_size}); -} - -DynamicDispatchNode make_linear_qa_qw_node( +void add_linear_qa_qw_node( ComputeGraph& graph, const QuantizationConfig& input_quant_config, const QuantizationConfig& weight_quant_config, @@ -615,8 +410,7 @@ DynamicDispatchNode make_linear_qa_qw_node( apply_bias = 0; } - // Add the compute node - return DynamicDispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), quantized_linear_global_wg_size, @@ -638,10 +432,10 @@ DynamicDispatchNode make_linear_qa_qw_node( // Resize args {fp_input}, // Resizing Logic - nullptr); + nullptr)); } -DynamicDispatchNode make_linear_dqa_qw_node( +void add_linear_dqa_qw_node( ComputeGraph& graph, const QuantizationConfig& input_quant_config, const QuantizationConfig& weight_quant_config, @@ -685,8 +479,7 @@ DynamicDispatchNode make_linear_dqa_qw_node( const ValueRef is_4bit_flag = weight_quant_config.nbits == 4 ? group_size : kDummyValueRef; - // Add the compute node - return DynamicDispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, pick_linear_dqa_qw_shader, quantized_linear_global_wg_size, @@ -712,7 +505,7 @@ DynamicDispatchNode make_linear_dqa_qw_node( // Resize args {is_4bit_flag, weight_data}, // Resizing Logic - resize_linear_qw_node); + resize_linear_qw_node)); } // @@ -770,7 +563,7 @@ void quantized_linear_impl( // 2. Input is not quantized if (!graph.can_use_int8_dot_product() || input_quant_config.granularity == kNoQuantization) { - DynamicDispatchNode linear_qw_node(make_linear_qw_node( + add_linear_qw_node( graph, weight_quant_config, fp_input, @@ -781,9 +574,8 @@ void quantized_linear_impl( group_size, bias_data, packed_bias, - output)); + output); - graph.execute_nodes().emplace_back(new DynamicDispatchNode(linear_qw_node)); return; } // Otherwise, use input and weight quantized linear computed with integer @@ -815,22 +607,18 @@ void quantized_linear_impl( // Non dynamically quantized input case if (!input_quant_config.is_dynamic) { - DynamicDispatchNode quantize_and_pack_linear_node( - make_quantize_and_pack_linear_input_node( - graph, - input_quant_config, - fp_input, - packed_input_scale, - packed_input_zp, - input_scale, - input_zp, - packed_int_input, - group_size)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(quantize_and_pack_linear_node)); - - DynamicDispatchNode linear_qa_qw_node(make_linear_qa_qw_node( + add_quantize_and_pack_4h4w_node( + graph, + input_quant_config, + fp_input, + packed_input_scale, + packed_input_zp, + input_scale, + input_zp, + packed_int_input, + group_size); + + add_linear_qa_qw_node( graph, input_quant_config, weight_quant_config, @@ -847,10 +635,7 @@ void quantized_linear_impl( group_size, bias_data, packed_bias, - output)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(linear_qa_qw_node)); + output); return; } @@ -871,21 +656,17 @@ void quantized_linear_impl( utils::kBuffer, utils::kWidthPacked); - DynamicDispatchNode quantize_and_pack_input_with_sums_node( - make_quantize_and_pack_linear_input_with_sums_node( - graph, - input_quant_config, - fp_input, - int_input_sums, - packed_input_scale, - packed_input_zp, - packed_int_input, - group_size)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(quantize_and_pack_input_with_sums_node)); - - DynamicDispatchNode linear_dqa_qw_node(make_linear_dqa_qw_node( + add_quantize_and_pack_4h4w_with_group_sums_node( + graph, + input_quant_config, + fp_input, + int_input_sums, + packed_input_scale, + packed_input_zp, + packed_int_input, + group_size); + + add_linear_dqa_qw_node( graph, input_quant_config, weight_quant_config, @@ -903,10 +684,7 @@ void quantized_linear_impl( group_size, bias_data, packed_bias, - output)); - - graph.execute_nodes().emplace_back( - new DynamicDispatchNode(linear_dqa_qw_node)); + output); } void linear_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index 72c1637a2c9..2b42c0bd150 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -14,8 +14,6 @@ #include #include -#include - namespace vkcompute { namespace { diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index fcc8fe4b265..e1914f350b7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size( const ValueRef xq_out = args.at(0).refs.at(0); - utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out); - global_wg_size[0] /= 2; + // Head dim texel size + const uint32_t D4 = utils::div_up_4(graph->size_at(-1, xq_out)); + // Divide by 2 since each invocation computes 2 output locations + const uint32_t D8 = utils::div_up(D4, uint32_t(2)); - return global_wg_size; + // Number of query heads + const uint32_t QH = graph->size_at(-2, xq_out); + // Input tokens sequence length + const uint32_t S = graph->size_at(-3, xq_out); + + return {D8, QH, S}; } void add_rotary_embedding_node( @@ -73,8 +80,14 @@ void add_rotary_embedding_node( VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin)); std::string kernel_name = "rotary_embedding"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out)); add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(xq_out), + graph.meta_ubo(xk_out), + graph.meta_ubo(freqs_cos)}; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -84,7 +97,7 @@ void add_rotary_embedding_node( {{{xq_out, xk_out}, vkapi::kWrite}, {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, // Parameter buffers - {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)}, + param_ubos, // Push Constants {}, // Specialization Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 4eed8b82834..d28d2c90fcb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { VK_CHECK_COND(graph.val_is_none(attn_mask)); const int64_t num_q_heads = graph.size_at(-2, q_projected); - const int64_t max_seq_len = graph.size_at(-3, q_projected); - + int64_t max_seq_len = graph.size_at(-3, q_projected); const int64_t max_context_len = graph.size_at(-3, k_cache); + const utils::StorageType attn_weights_storage = + graph.storage_type_of(q_projected); + + // If using buffer storage for attn weights, we need to ensure that the buffer + // numel limit is not exceeded. If needed, manually adjust max_seq_len based + // on the buffer numel limit. + if (attn_weights_storage == utils::kBuffer) { + const int64_t max_buffer_numel = graph.max_buffer_numel(); + if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) { + // Compute the maximum possible value for max_seq_len that will hit + // the buffer numel limit. + max_seq_len = max_buffer_numel / (num_q_heads * max_context_len); + // Adjust down to the nearest multiple of 4 to make sure the limit is + // not hit. + if (max_seq_len % 4 != 0) { + max_seq_len = (max_seq_len / 4) * 4; + } else { + max_seq_len -= 4; + } + } + } + std::vector attn_weight_full_sizes = { 1, // batch num_q_heads, @@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); TmpTensor attn_weights_softmax( &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); add_sdpa_compute_attn_weights_node( @@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl( utils::StorageType cache_storage = graph.storage_type_of(q_projected); const ValueRef k_cache = - prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); const ValueRef v_cache = - prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl( (void)sequence_len; - utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const utils::StorageType cache_storage = graph.storage_type_of(q_projected); const ValueRef k_cache = graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); const ValueRef v_cache = diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp index f87af08ee69..4e62ae8806d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp @@ -8,134 +8,131 @@ #include -#include +#include +#include #include -#include #include + #include -namespace vkcompute { +#include -void add_split_with_sizes_default_node( - ComputeGraph& graph, - ValueRef in, - const std::vector& split_sizes, - int64_t dim, - ValueRef out_list_ref) { - const ValueListPtr out_list = graph.get_value_list(out_list_ref); +namespace vkcompute { - const int64_t input_ndim = graph.dim_of(in); +using utils::GPUMemoryLayout; +using utils::StorageType; + +void resize_split_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef input = args.at(0).refs.at(0); + const ValueRef split_sizes_ref = args.at(1).refs.at(0); + const ValueRef dim_ref = args.at(2).refs.at(0); + const ValueRef out_list_ref = args.at(3).refs.at(0); + + const ValueListPtr out_list = graph->get_value_list(out_list_ref); + const std::vector split_sizes = + *(graph->get_int_list(split_sizes_ref)); + const int64_t dim = graph->extract_scalar(dim_ref); + + const int64_t input_ndim = graph->dim_of(input); const DimIndex dim_index = dim < 0 ? static_cast(dim) : static_cast(dim - input_ndim); - VK_CHECK_COND(out_list->size() == split_sizes.size()); + std::vector input_sizes = graph->sizes_of(input); for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) { const int64_t split_size = split_sizes.at(split_idx); const ValueRef out_ref = out_list->at(split_idx); - VK_CHECK_COND(dim_at(graph.sizes_of(out_ref), dim_index) == split_size); - } - - const auto packed_dim = graph.packed_dim_of(in); - const auto packed_dim_index = static_cast(kWidth4D - packed_dim); + std::vector out_sizes = input_sizes; + out_sizes.at(dim_index) = split_size; - // Index of dimension to be concatenated in (w, h, c * b) coordinate system - const auto dim_xyz_index = std::min(2, -dim_index - 1); - - utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false); - utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false); - - const bool is_splitting_channel = (dim_index == kChannel4D); - - // if splitting channels - if (is_splitting_channel) { - // set source offset w as channel size of the input tensor - src_offset[3] = dim_at(graph.sizes_of(in), kChannel4D); + graph->virtual_resize(out_ref, out_sizes); } +} - for (ValueRef out_ref : *out_list) { - // Doesn't need to use split_size since we have already verified that the - // output tensor's size matches with the split_size. - const auto out_channel_size = dim_at(graph.sizes_of(out_ref), kChannel4D); - const utils::ivec3 range = graph.logical_limits_of(out_ref); - - if (dim_index == packed_dim_index) { - // if splitting channels, use add_copy_channel_offset_node function as - // add_copy_packed_dim_offset_node does not support channel packing - if (is_splitting_channel) { - add_copy_channel_offset_node( - graph, in, out_channel_size, src_offset[2], dst_offset[2], out_ref); - src_offset[dim_xyz_index] += out_channel_size; - } else { - // dst_offset[3] is not used now but will be used in the future when - // add_copy_packed_dim_offset_node will support channel packing - // - // set destination offset w as channel size of the output tensor if - // splitting channel - dst_offset[3] = is_splitting_channel ? out_channel_size : 0; - add_copy_packed_dim_offset_node( - graph, in, range, src_offset, dst_offset, out_ref); - src_offset[dim_xyz_index] += - dim_at(graph.sizes_of(out_ref), packed_dim_index); - } - } else { - // set destination offset w as channel size of the output tensor if - // splitting channels - dst_offset[3] = is_splitting_channel ? out_channel_size : 0; - add_copy_offset_node( - graph, in, range, src_offset, dst_offset, out_ref, false, true); - src_offset[dim_xyz_index] += - is_splitting_channel ? out_channel_size : range[dim_xyz_index]; - } +void add_split_node( + ComputeGraph& graph, + const ValueRef input, + const std::vector& split_sizes, + const int64_t dim, + const ValueRef out, + const int split_idx) { + std::string kernel_name = "split"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), graph.meta_ubo(input)}; + + int64_t dim_whcn = nchw_dim_to_whcn_dim(dim, graph.dim_of(input)); + + // Calculate the offset for this split by summing previous split sizes + int64_t split_offset = 0; + for (int i = 0; i < split_idx; i++) { + split_offset += split_sizes[i]; } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {utils::safe_downcast(dim_whcn), + static_cast(split_idx), + static_cast(split_offset)}, + // Resize Args + {}, + // Resizing Logic + nullptr)); } -void add_split_with_sizes_default_node( +void add_split_with_sizes_node( ComputeGraph& graph, - ValueRef in, - ValueRef split_sizes_ref, - ValueRef dim_ref, - ValueRef out) { - int64_t dim = graph.extract_scalar(dim_ref); - std::vector split_sizes = *(graph.get_int_list(split_sizes_ref)); + const ValueRef input, + const std::vector& split_sizes, + const int64_t dim, + const ValueRef out_list_ref) { + const ValueListPtr out_list = graph.get_value_list(out_list_ref); + + VK_CHECK_COND(out_list->size() == split_sizes.size()); - add_split_with_sizes_default_node(graph, in, split_sizes, dim, out); + // Dispatch a shader for each output tensor + for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) { + const ValueRef out_ref = out_list->at(split_idx); + add_split_node(graph, input, split_sizes, dim, out_ref, split_idx); + } } void split_with_sizes_copy_default( ComputeGraph& graph, const std::vector& args) { - add_split_with_sizes_default_node(graph, args[0], args[1], args[2], args[3]); -} - -void add_split_tensor_node( - ComputeGraph& graph, - ValueRef in, - ValueRef split_size_ref, - ValueRef dim_ref, - ValueRef out) { - const int64_t split_size = graph.extract_scalar(split_size_ref); - const int64_t dim = graph.extract_scalar(dim_ref); - - const int64_t input_ndim = graph.dim_of(in); - const DimIndex dim_index = dim < 0 ? static_cast(dim) - : static_cast(dim - input_ndim); - const int64_t size = dim_at(graph.sizes_of(in), dim_index); - const std::vector split_sizes(size / split_size, split_size); + ValueRef input = args[0]; + ValueRef split_sizes_ref = args[1]; + ValueRef dim_ref = args[2]; + ValueRef out_list_ref = args[3]; - add_split_with_sizes_default_node(graph, in, split_sizes, dim, out); -} + int64_t dim = graph.extract_scalar(dim_ref); + std::vector split_sizes = *(graph.get_int_list(split_sizes_ref)); -void split_tensor(ComputeGraph& graph, const std::vector& args) { - add_split_tensor_node(graph, args[0], args[1], args[2], args[3]); + add_split_with_sizes_node(graph, input, split_sizes, dim, out_list_ref); } REGISTER_OPERATORS { VK_REGISTER_OP( aten.split_with_sizes_copy.default, split_with_sizes_copy_default); - VK_REGISTER_OP(aten.split.Tensor, split_tensor); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp index f07522d2578..eb03639abf1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -81,9 +81,58 @@ void sym_add(ComputeGraph& graph, const std::vector& args) { new ExecuteNode(resize_sym_add_node, args)); } +void select_as_symint_impl( + ComputeGraph* graph, + const std::vector& unused, + const std::vector& args) { + (void)unused; // Unused parameter + + const ValueRef x = args.at(0); + const ValueRef dim = args.at(1); + const ValueRef index = args.at(2); + const ValueRef out = args.at(3); + + const int64_t dim_val = graph->extract_scalar(dim); + int64_t index_val = graph->extract_scalar(index); + + const std::vector x_sizes = graph->sizes_of(x); + const vkapi::ScalarType x_dtype = graph->dtype_of(x); + + if (index_val < 0) { + index_val += x_sizes[dim_val]; + } + + const StagingPtr x_staging = graph->get_staging(graph->staging_of(x)); + + int32_t x_val; + switch (x_dtype) { + case vkapi::ScalarType::Int: + x_val = x_staging->select_element_at_dim( + x_sizes, dim_val, index_val); + break; + case vkapi::ScalarType::Long: + x_val = static_cast(x_staging->select_element_at_dim( + x_sizes, dim_val, index_val)); + break; + default: + VK_THROW("Unsupported dtype for select_as_symint"); + } + + graph->set_symint(out, x_val); +} + +void select_as_symint(ComputeGraph& graph, const std::vector& args) { + select_as_symint_impl(&graph, {}, args); + + graph.execute_nodes().emplace_back(new ExecuteNode( + select_as_symint_impl, args, {}, "select_as_symint", true)); + graph.set_has_data_dependent_shapes(); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sym_size.int, sym_size_int); VK_REGISTER_OP(add, sym_add); + VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp b/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp index 60127ecf9bd..1823271824a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Transfer.cpp @@ -50,15 +50,16 @@ void add_transfer_copy_node( (transfer_type == TransferType::SELECT || graph.is_scalar_or_none(step_ref)); - vkapi::ParamsBindList param_buffers; + vkapi::ParamsBindList param_ubos = {graph.meta_ubo(out), graph.meta_ubo(in)}; + if (!param_is_scalar) { if (transfer_type == TransferType::SELECT) { - param_buffers = { - graph.get_or_create_int_param_buffer(index_or_start_ref, 0)}; + param_ubos.append( + graph.get_or_create_int_param_buffer(index_or_start_ref, 0)); } else { // TransferType::SLICE - param_buffers = { - graph.get_or_create_int_param_buffer(index_or_start_ref, 0), - graph.get_or_create_int_param_buffer(step_ref, 1)}; + param_ubos.append( + graph.get_or_create_int_param_buffer(index_or_start_ref, 0)); + param_ubos.append(graph.get_or_create_int_param_buffer(step_ref, 1)); } } else { transfer_params.index_or_start_ref = @@ -69,18 +70,6 @@ void add_transfer_copy_node( } std::vector push_constants; - push_constants.reserve(graph.is_buffer_storage(out) ? 5 : 3); - - if (graph.is_buffer_storage(out)) { - push_constants.emplace_back(graph.sizes_pc_of(in)); - push_constants.emplace_back(graph.strides_pc_of(out)); - push_constants.emplace_back(graph.strides_pc_of(in)); - push_constants.emplace_back(graph.numel_pc_of(out)); - } else { - push_constants.emplace_back(graph.sizes_pc_of(out)); - push_constants.emplace_back(graph.sizes_pc_of(in)); - } - if (param_is_scalar) { push_constants.emplace_back(&transfer_params, sizeof(transfer_params)); } else { @@ -88,11 +77,6 @@ void add_transfer_copy_node( &transfer_params.dim, sizeof(transfer_params.dim)); } - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(out), - graph.hashed_layout_of(in), - }; - // Determine the shader directly std::string kernel_name; if (transfer_type == TransferType::SELECT) { @@ -115,11 +99,11 @@ void add_transfer_copy_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Parameter buffers - param_buffers, + param_ubos, // Push Constants push_constants, // Specialization Constants - spec_vars, + {}, // Resize Args resize_args, // Resizing Logic diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 36a8ee4c3b1..602fe1ef129 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -67,6 +67,18 @@ void resize_unsqueeze_node( std::vector out_sizes = graph->sizes_of(in); + std::vector unsqueezed_dims; + + if (graph->val_is_int_list(dims_ref)) { + const IntListPtr dims = graph->get_int_list(dims_ref); + for (int64_t d : *dims) { + unsqueezed_dims.push_back(d); + } + } else { + const int64_t dim = graph->extract_scalar(dims_ref); + unsqueezed_dims.push_back(dim); + } + // Insert singleton dimensions at the specified positions for (auto dim : dims_vec) { int64_t d = dim; diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 8701a6246b0..5e2c898573a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -60,6 +60,16 @@ void resize_view_node( } } +void resize_to_dim_order_copy_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); +} + void add_view_node( ComputeGraph& graph, ValueRef in, @@ -98,6 +108,11 @@ void add_view_copy_buffer_node( std::string kernel_name = "view_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 1 : 0; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -110,7 +125,41 @@ void add_view_copy_buffer_node( // Push Constants {}, // Specialization Constants + {all_contiguous_int}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_buffer"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 1 : 0; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter Buffers + {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)}, + // Push Constants {}, + // Specialization Constants + {all_contiguous_int}, // Resize Args resize_args, // Resizing Logic @@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector& args) { return add_view_node(graph, in, sizes, out); } +void to_dim_order_copy(ComputeGraph& graph, const std::vector& args) { + int args_idx = 0; + const ValueRef in = args.at(args_idx++); + const ValueRef dtype = args.at(args_idx++); + (void)dtype; + const ValueRef layout = args.at(args_idx++); + (void)layout; + const ValueRef device = args.at(args_idx++); + (void)device; + const ValueRef pin_memory = args.at(args_idx++); + (void)pin_memory; + const ValueRef non_blocking = args.at(args_idx++); + (void)non_blocking; + const ValueRef dim_order = args.at(args_idx++); + (void)dim_order; + + const ValueRef out = args.at(args_idx++); + + VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out)); + + if (graph.dtype_of(in) == graph.dtype_of(out)) { + return add_view_copy_buffer_node( + graph, in, out, {}, resize_to_dim_order_copy_node); + } + + return add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_dim_order_copy_node); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten.view_copy.default, view); + VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index 7a7a8d57742..c8e52492417 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -24,6 +24,19 @@ void add_view_copy_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_buffer compute shader. This can be used to + * implement ops that preserve the "contiguous" indexes of elements between the + * input and output while converting between different data types such as + * view_copy with dtype conversion. + */ +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h index b62bf661995..05234c7790f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -69,7 +69,7 @@ template < std::is_integral::value && std::is_signed::value, int>::type = 0> T nchw_dim_to_whcn_dim(const T& nchw_dim, const int64_t ndim) { - return ndim - 1 - nchw_dim; + return ndim - 1 - normalize(nchw_dim, ndim); } } // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl index c1d90fadf7e..e2d198b129f 100644 --- a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl +++ b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_buffer.glsl @@ -12,8 +12,6 @@ ${define_active_storage_type("texture3d")} -#extension GL_EXT_debug_printf : enable - layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", "int", "texture3d")} @@ -33,12 +31,6 @@ void main() { // Pack four 8-bit values equal to 1 into a single uint int packed = (1 << 0) | (1 << 8) | (1 << 16) | (1 << 24); - debugPrintfEXT( - "t_out[%i, %i] = %i\\n", - lpos.x, lpos.y, - packed); - - // Placeholder: just copy input to output ivec4 in_texel = ivec4(packed); imageStore(t_out, lpos, in_texel); diff --git a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl index be6717efdaa..80e6fc27909 100644 --- a/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl +++ b/backends/vulkan/test/custom_ops/glsl/packed_int32_canvas_texture3d.glsl @@ -12,8 +12,6 @@ ${define_active_storage_type("texture2d")} -#extension GL_EXT_debug_printf : enable - layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", "int", "texture3d")} @@ -33,12 +31,6 @@ void main() { // Pack four 8-bit values equal to 1 into a single uint int packed = (1 << 0) | (1 << 8) | (1 << 16) | (1 << 24); - debugPrintfEXT( - "t_out[%i, %i] = %i\\n", - lpos.x, lpos.y, - packed); - - // Placeholder: just copy input to output ivec4 in_texel = ivec4(packed); imageStore(t_out, lpos, in_texel); diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp index 8762fe4c0d1..bbd4af7579c 100644 --- a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp @@ -47,11 +47,15 @@ TestCase create_test_case_from_config( std::vector input_size = { 1, config.channels.in, config.input_size.h, config.input_size.w}; + utils::GPUMemoryLayout io_memory_layout = storage_type == utils::kBuffer + ? utils::kWidthPacked + : utils::kChannelsPacked; + ValueSpec input_tensor( input_size, input_dtype, storage_type, - utils::kChannelsPacked, + io_memory_layout, DataGenType::RANDOM); if (debugging()) { @@ -139,7 +143,7 @@ TestCase create_test_case_from_config( {1, config.channels.out, H_out, W_out}, input_dtype, storage_type, - utils::kChannelsPacked, + io_memory_layout, DataGenType::ZEROS); // Add all specs to test case for q8ta_q8csw_q8to operation @@ -182,7 +186,8 @@ std::vector generate_quantized_conv2d_easy_cases() { config.op_name = "conv2d_q8ta_q8csw_q8to"; // Test with both storage types and data types for completeness - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; std::vector float_types = {vkapi::kFloat}; // Generate test cases for each combination @@ -341,7 +346,8 @@ std::vector generate_quantized_conv2d_test_cases() { 4}}; // Test with different storage types and data types - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; // Generate test cases for each combination for (auto& config : configs) { @@ -621,7 +627,7 @@ int main(int argc, char* argv[]) { quantized_conv2d_flop_calculator, "QuantizedConv2dQ8ToQ8To", 0, - 10, + 1, ref_fn); return 0; diff --git a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp index 5799bc194c9..eb8e6908060 100644 --- a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp +++ b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp @@ -38,21 +38,17 @@ TestCase create_quantized_add_test_case( // Set the operator name for the test case test_case.set_operator_name("et_vk.add_q8ta_q8ta_q8to.test"); + utils::GPUMemoryLayout io_memory_layout = storage_type == utils::kBuffer + ? utils::kWidthPacked + : utils::kChannelsPacked; + // Input tensor A (float/half) ValueSpec input_a( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::RANDOM); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); // Input tensor B (float/half) ValueSpec input_b( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::RANDOM); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); // Quantization parameters for input A float input_a_scale_val = 0.007843; // 2/255 approximately @@ -81,11 +77,7 @@ TestCase create_quantized_add_test_case( // Output tensor (float/half) ValueSpec output( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::ZEROS); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::ZEROS); // Add all specs to test case for q8ta_q8ta_q8to add operation test_case.add_input_spec(input_a); @@ -119,7 +111,8 @@ std::vector generate_quantized_add_test_cases() { }; // Storage types to test - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; // Data types to test std::vector data_types = {vkapi::kFloat}; diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index dfb9a2865ba..b21a8458a89 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -6,7 +6,6 @@ import itertools - from collections import namedtuple from typing import Callable @@ -1140,6 +1139,55 @@ def get_embedding_inputs(): return test_suite_wpack +@register_test_suite("aten.gather.default") +def get_gather_inputs(): + Test = namedtuple("GatherTest", ["input", "dim", "index"]) + Test.__new__.__defaults__ = (None, None, None) + + test_cases = [ + # Simple 2D case + Test(input=[4, 4], dim=1, index=[[1, 2], [2, 1], [3, 3], [3, 1]]), + # # 1D cases + Test(input=[10], dim=0, index=[0, 2, 5, 7, 9]), + Test(input=[8], dim=0, index=[1, 3, 5]), + # # 2D cases with different dims + Test(input=[5, 8], dim=0, index=[[0, 1], [2, 3], [4, 0]]), + Test( + input=[5, 8], + dim=1, + index=[[0, 2, 4], [1, 3, 5], [6, 7, 0], [1, 2, 3], [4, 5, 6]], + ), + # # 3D cases + Test( + input=[3, 4, 5], + dim=0, + index=[ + [[0, 1, 2, 0, 1], [1, 2, 0, 1, 2], [2, 0, 1, 2, 0], [0, 1, 2, 0, 1]] + ], + ), + Test( + input=[3, 4, 5], + dim=1, + index=[ + [[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2], [0, 1, 2, 3]] + ], + ), + Test( + input=[3, 4, 5], dim=2, index=[[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 0]]] + ), + ] + + test_suite = VkTestSuite( + [tuple(tc) + (False, "false", "false") for tc in test_cases] + ) + + test_suite.dtypes = ["at::kFloat"] + test_suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"] + test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] + + return test_suite + + @register_test_suite("aten.unsqueeze_copy.default") def get_unsqueeze_inputs(): test_suite = VkTestSuite( @@ -1470,64 +1518,11 @@ def get_split_with_sizes_inputs(): test_suite.layouts = [ "utils::kWidthPacked", - "utils::kHeightPacked", - "utils::kChannelsPacked", - ] - test_suite.data_gen = "make_seq_tensor" - test_suite.dtypes = ["at::kFloat"] - return test_suite - - -@register_test_suite("aten.split.Tensor") -def get_split_tensor_inputs(): - test_suite = VkTestSuite( - [ - # Split on Width - ((S1, 7, 10, 12), 12, 3), - ((S1, 7, 10, 12), 3, 3), - ((S1, 7, 10, 12), 1, 3), - ((7, 10, 12), 12, 2), - ((7, 10, 12), 3, 2), - ((7, 10, 12), 1, 2), - ((10, 12), 12, 1), - ((10, 12), 3, 1), - ((10, 12), 1, 1), - ((12,), 12, 0), - ((12,), 3, 0), - ((12,), 1, 0), - # Split on Height - ((S1, 7, 12, 8), 12, 2), - ((S1, 7, 12, 8), 3, 2), - ((S1, 7, 12, 8), 1, 2), - ((7, 12, 8), 12, 1), - ((7, 12, 8), 3, 1), - ((7, 12, 8), 1, 1), - ((12, 8), 12, 0), - ((12, 8), 3, 0), - ((12, 8), 1, 0), - # Split on Batch - ((12, 7, 10, 10), 12, 0), - ((12, 7, 10, 10), 3, 0), - ((12, 7, 10, 10), 1, 0), - # Split on Channel - ((7, 15, 10, 10), 15, 1), - ((7, 15, 10, 10), 5, 1), - ((7, 15, 10, 10), 3, 1), - ((7, 15, 10, 10), 1, 1), - ((15, 10, 10), 15, 0), - ((15, 10, 10), 5, 0), - ((15, 10, 10), 3, 0), - ((15, 10, 10), 1, 0), - ] - ) - - test_suite.layouts = [ - "utils::kWidthPacked", - "utils::kHeightPacked", "utils::kChannelsPacked", ] test_suite.data_gen = "make_seq_tensor" test_suite.dtypes = ["at::kFloat"] + test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] return test_suite diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp deleted file mode 100644 index 3b1094a1e84..00000000000 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ /dev/null @@ -1,786 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -std::tuple choose_qparams_tensor_out( - const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ET_UNUSED double eps, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out); - -std::tuple choose_qparams_per_token_asymmetric_out( - const Tensor& input, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out); - -// Wrapper function for choose_qparams_tensor_out without context -Tensor& choose_qparams_tensor_out_no_context( - const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ET_UNUSED double eps, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out) { - torch::executor::native::choose_qparams_tensor_out( - input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); - return scale_out; -} - -// Wrapper function for choose_qparams_per_token_asymmetric_out without context -Tensor& choose_qparams_per_token_asymmetric_out_no_context( - const Tensor& input, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out) { - torch::executor::native::choose_qparams_per_token_asymmetric_out( - input, dtype, scale_out, zero_point_out); - return scale_out; -} - -// ATen wrapper for choose_qparams_tensor -std::tuple choose_qparams_tensor_aten( - const at::Tensor& input, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); - auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong)); - double eps = 1e-7; - - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - // Use WRAP_TO_ATEN with the wrapper function - WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5) - (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out); - - return {scale_out, zero_point_out}; -} - -// ATen wrapper for choose_qparams_per_token_asymmetric -std::tuple choose_qparams_per_token_asymmetric_aten( - const at::Tensor& input, - at::ScalarType dtype) { - // Calculate output sizes for scale and zero_point tensors - std::vector output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - auto scale_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); - auto zero_point_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); - - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - // Use WRAP_TO_ATEN with the wrapper function - WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2) - (input, et_dtype, scale_out, zero_point_out); - - return {scale_out, zero_point_out}; -} - -} // namespace native -} // namespace executor -} // namespace torch - -// -// Reference Implementation -// - -/* - * Reference implementation of choose_qparams_tensor - */ -std::tuple choose_qparams_tensor_reference_impl( - const at::Tensor& input, - int64_t quant_min, - int64_t quant_max) { - // Create output tensors - at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_out = - at::empty({}, at::device(at::kCPU).dtype(at::kLong)); - - // Find min and max values in the input tensor - float min_val = input.min().item(); - float max_val = input.max().item(); - - // Extend the [min, max] interval to ensure it contains 0 - min_val = std::min(min_val, 0.f); - max_val = std::max(max_val, 0.f); - - // Calculate scale - double scale = - (static_cast(max_val) - min_val) / (quant_max - quant_min); - - // Handle small scale - constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; - if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { - scale = 0.1; - } - - if (scale < SMALL_SCALE_THRESHOLD) { - float org_scale = scale; - scale = SMALL_SCALE_THRESHOLD; - // Adjust min and max based on new scale - if (min_val == 0.0f) { - max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else if (max_val == 0.0f) { - min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; - min_val *= amplifier; - max_val *= amplifier; - } - } - - // Calculate zero point - double zero_point_from_min = quant_min - min_val / static_cast(scale); - double zero_point_from_max = quant_max - max_val / static_cast(scale); - double zero_point_from_min_error = - std::abs(quant_min) - std::abs(min_val / static_cast(scale)); - double zero_point_from_max_error = - std::abs(quant_max) - std::abs(max_val / static_cast(scale)); - double initial_zero_point = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Nudge zero point to be an integer - int64_t nudged_zero_point = 0; - if (initial_zero_point < quant_min) { - nudged_zero_point = quant_min; - } else if (initial_zero_point > quant_max) { - nudged_zero_point = quant_max; - } else { - nudged_zero_point = std::nearbyint(static_cast(initial_zero_point)); - } - - // Set output values - use item_mutable() for scalar tensors - scale_out.fill_(scale); - zero_point_out.fill_(nudged_zero_point); - - return std::make_tuple(scale_out, zero_point_out); -} - -/* - * Reference implementation of choose_qparams_per_token_asymmetric - */ -std::tuple -choose_qparams_per_token_asymmetric_reference_impl( - const at::Tensor& input, - at::ScalarType dtype) { - // For per-token quantization, we need to compute scale and zero_point for - // each token - int64_t quant_min = -128; - int64_t quant_max = 127; - - // Calculate output sizes - std::vector output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - // Create output tensors - at::Tensor scale_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); - - // Calculate number of tokens - int64_t num_tokens = 1; - for (int64_t i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - - // Process each token - for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) { - at::Tensor token = reshaped_input[token_idx]; - - // Find min and max values for this token - float min_val = token.min().item(); - float max_val = token.max().item(); - - // Extend the [min, max] interval to ensure it contains 0 - min_val = std::min(min_val, 0.f); - max_val = std::max(max_val, 0.f); - - // Calculate scale - double scale = - (static_cast(max_val) - min_val) / (quant_max - quant_min); - - // Handle small scale - constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; - if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { - scale = 0.1; - } - - if (scale < SMALL_SCALE_THRESHOLD) { - float org_scale = scale; - scale = SMALL_SCALE_THRESHOLD; - // Adjust min and max based on new scale - if (min_val == 0.0f) { - max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else if (max_val == 0.0f) { - min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; - min_val *= amplifier; - max_val *= amplifier; - } - } - - // Calculate zero point - double zero_point_from_min = - quant_min - min_val / static_cast(scale); - double zero_point_from_max = - quant_max - max_val / static_cast(scale); - double zero_point_from_min_error = - std::abs(quant_min) - std::abs(min_val / static_cast(scale)); - double zero_point_from_max_error = - std::abs(quant_max) - std::abs(max_val / static_cast(scale)); - double initial_zero_point = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Nudge zero point to be an integer - int64_t nudged_zero_point = 0; - if (initial_zero_point < quant_min) { - nudged_zero_point = quant_min; - } else if (initial_zero_point > quant_max) { - nudged_zero_point = quant_max; - } else { - nudged_zero_point = - std::nearbyint(static_cast(initial_zero_point)); - } - - // Set output values for this token - use index_put_ for safety - scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale); - zero_point_out.view({num_tokens, 1}) - .index_put_({token_idx, 0}, nudged_zero_point); - } - - return std::make_tuple(scale_out, zero_point_out); -} - -// Forward declaration of implementation functions -void test_vulkan_choose_qparams_tensor_impl( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_choose_qparams_per_token_asymmetric_impl( - const std::vector& input_sizes, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_choose_qparams_tensor( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Test with buffer storage - test_vulkan_choose_qparams_tensor_impl( - input_sizes, - quant_min, - quant_max, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Test with texture storage - test_vulkan_choose_qparams_tensor_impl( - input_sizes, - quant_min, - quant_max, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_choose_qparams_per_token_asymmetric( - const std::vector& input_sizes, - at::ScalarType dtype) { - // Test with buffer storage - test_vulkan_choose_qparams_per_token_asymmetric_impl( - input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); - - // Test with texture storage - test_vulkan_choose_qparams_per_token_asymmetric_impl( - input_sizes, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_choose_qparams_tensor( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - choose_qparams_tensor_reference_impl(input, quant_min, quant_max); - - // Get implementation output - auto [impl_scale, impl_zero_point] = - torch::executor::native::choose_qparams_tensor_aten( - input, quant_min, quant_max, dtype); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale, impl_scale); - const bool zero_point_correct = - at::equal(reference_zero_point, impl_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "implementation scale:" << std::endl; - std::cout << impl_scale << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "implementation zero_point:" << std::endl; - std::cout << impl_zero_point << std::endl; - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -void test_vulkan_choose_qparams_tensor_impl( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - torch::executor::native::choose_qparams_tensor_aten( - input, quant_min, quant_max, dtype); - - // Build Vulkan choose_qparams_tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - // Output tensors - const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); - const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); - - // Create output tuple - const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); - - // Add eps and dtype parameters to match ATen signature - const ValueRef r_eps = graph.add_scalar(6.1e-5); - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor") - (graph, - { - r_input.value, - r_quant_min, - r_quant_max, - r_eps, - r_dtype, - r_out_tuple, - }); - - ValueRef staging_scale = graph.set_output_tensor(r_scale); - ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); - - graph.prepare(); - - graph.prepack(); - - // Run Vulkan choose_qparams_tensor - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - graph.execute(); - - // Create output tensors to hold the results - use types that match GPU output - at::Tensor vk_scale = - at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous(); - at::Tensor vk_zero_point = - at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous(); - - // Copy results from GPU to CPU - graph.copy_from_staging( - staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); - graph.copy_from_staging( - staging_zero_point, - vk_zero_point.mutable_data_ptr(), - vk_zero_point.numel()); - - // Convert reference values to match Vulkan output types for comparison - at::Tensor reference_scale_float = reference_scale.to(at::kFloat); - at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale_float, vk_scale); - const bool zero_point_correct = - at::equal(reference_zero_point_int, vk_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - // make sure that there arent a ton of elements in the input tensor - if (input.numel() < 100) { - std::cout << "input:" << std::endl; - std::cout << input << "\n" << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "vulkan scale:" << std::endl; - std::cout << vk_scale << "\n" << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "vulkan zero_point:" << std::endl; - std::cout << vk_zero_point << std::endl; - } - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { - test_reference_choose_qparams_tensor( - {2, 3, 4}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_uint8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {5, 3, 2, 4}, // input sizes - 0, // quant_min - 255, // quant_max - at::kByte); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_2D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {5, 5}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_3D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {12, 8, 2}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {10, 10, 6, 4}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -void test_reference_choose_qparams_per_token_asymmetric( - const std::vector& input_sizes, - at::ScalarType dtype) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - choose_qparams_per_token_asymmetric_reference_impl(input, dtype); - - // Get implementation output - auto [impl_scale, impl_zero_point] = - torch::executor::native::choose_qparams_per_token_asymmetric_aten( - input, dtype); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale, impl_scale); - const bool zero_point_correct = - at::equal(reference_zero_point, impl_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "implementation scale:" << std::endl; - std::cout << impl_scale << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "implementation zero_point:" << std::endl; - std::cout << impl_zero_point << std::endl; - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -void test_vulkan_choose_qparams_per_token_asymmetric_impl( - const std::vector& input_sizes, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Calculate output sizes - std::vector output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - // Get reference output - auto [reference_scale, reference_zero_point] = - torch::executor::native::choose_qparams_per_token_asymmetric_aten( - input, dtype); - - // Build Vulkan choose_qparams_per_token_asymmetric graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - // Output tensors - const ValueRef r_scale = - graph.add_tensor(output_sizes, vkapi::kFloat, out_storage); - const ValueRef r_zero_point = - graph.add_tensor(output_sizes, vkapi::kInt, out_storage); - - // Create output tuple - const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); - - // Add dtype parameter to match ATen signature - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN( - "quantized_decomposed.choose_qparams_per_token_asymmetric.default") - (graph, - { - r_input.value, - r_dtype, - r_out_tuple, - }); - - ValueRef staging_scale = graph.set_output_tensor(r_scale); - ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); - - graph.prepare(); - - graph.prepack(); - - // Run Vulkan choose_qparams_per_token_asymmetric - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - graph.execute(); - - // Create output tensors to hold the results - use types that match GPU output - at::Tensor vk_scale = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat)) - .contiguous(); - at::Tensor vk_zero_point = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt)) - .contiguous(); - - // Copy results from GPU to CPU - graph.copy_from_staging( - staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); - graph.copy_from_staging( - staging_zero_point, - vk_zero_point.mutable_data_ptr(), - vk_zero_point.numel()); - - // Convert reference values to match Vulkan output types for comparison - at::Tensor reference_scale_float = reference_scale.to(at::kFloat); - at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale_float, vk_scale); - const bool zero_point_correct = - at::equal(reference_zero_point_int, vk_zero_point); - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - if (input.numel() < 100) { - std::cout << "input:" << std::endl; - std::cout << input << "\n" << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "vulkan scale:" << std::endl; - std::cout << vk_scale << "\n" << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "vulkan zero_point:" << std::endl; - std::cout << vk_zero_point << std::endl; - } - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -TEST( - VulkanChooseQparamsTest, - test_reference_choose_qparams_per_token_asymmetric_int8) { - test_reference_choose_qparams_per_token_asymmetric( - {2, 3, 4}, // input sizes (2*3=6 tokens) - at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_1D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({7}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_2D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({2, 2}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_3D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({3, 6, 4}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({128, 2, 16, 3}, at::kChar); -} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp deleted file mode 100644 index 9fca2c632d3..00000000000 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ /dev/null @@ -1,2492 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include -#include -#include - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -Tensor& dequantize_per_tensor_out( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out); - -Tensor& dequantize_per_token_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - ScalarType out_dtype, - Tensor& out); - -Tensor& dequantize_per_channel_out( - const Tensor& input, - const Tensor& scale, - const std::optional& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out); - -Tensor& dequantize_per_tensor_tensor_args_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out); - -// Wrapper function for dequantize_per_tensor_out without context -Tensor& dequantize_per_tensor_out_no_context( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_tensor_out( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); -} - -// Wrapper function for dequantize_per_token_out without context -Tensor& dequantize_per_token_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - ScalarType out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_token_out( - input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); -} - -// Wrapper function for dequantize_per_channel_out without context -Tensor& dequantize_per_channel_out_no_context( - const Tensor& input, - const Tensor& scale, - const std::optional& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_channel_out( - input, - scale, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - out); -} - -// Wrapper function for dequantize_per_tensor_tensor_args_out without context -Tensor& dequantize_per_tensor_tensor_args_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_tensor_tensor_args_out( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); -} - -// ATen wrapper for dequantize_per_tensor -at::Tensor dequantize_per_tensor_aten( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) - (input, - scale, - zero_point, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_token -at::Tensor dequantize_per_token_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) - (input, - scale, - zero_points, - quant_min, - quant_max, - et_dtype, - et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_channel -at::Tensor dequantize_per_channel_aten( - const at::Tensor& input, - const at::Tensor& scale, - const std::optional& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_channel_out_no_context, 8) - (input, - scale, - zero_points, - axis, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_tensor with tensor args -at::Tensor dequantize_per_tensor_tensor_args_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_tensor_tensor_args_out_no_context, 7) - (input, - scale, - zero_point, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch - -void check_dequantize_args( - int64_t quant_min, - int64_t quant_max, - c10::ScalarType in_dtype, - c10::ScalarType out_dtype) { - using namespace vkcompute; - - // Check that quant_min <= quant_max - VK_CHECK_COND( - quant_min <= quant_max, - "quant_min must be <= quant_max, got quant_min: ", - quant_min, - " quant_max: ", - quant_max); - - // Check that input dtype is a quantized type - switch (in_dtype) { - case c10::kByte: - case c10::kChar: - case c10::kShort: - case c10::kInt: - case c10::kLong: - break; - default: - VK_THROW( - "Unsupported input dtype: ", - scalar_type_name(in_dtype), - " (", - static_cast(in_dtype), - ")"); - } - - // Check that output dtype is a floating point type - switch (out_dtype) { - case c10::kHalf: - case c10::kFloat: - case c10::kDouble: - break; - default: - VK_THROW( - "Unsupported output dtype: ", - scalar_type_name(out_dtype), - " (", - static_cast(out_dtype), - ")"); - } -} - -/** - * Helper function to validate dequantize_per_channel arguments - * Similar to the validation in quantize_test.cpp - */ -void check_dequantize_per_channel_args( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis) { - // Normalize axis - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes.size(); - } - - ASSERT_GE(normalized_axis, 0) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be >= 0"; - - ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be < input.dim() " << input_sizes.size(); - - int64_t num_channels = input_sizes[normalized_axis]; - - ASSERT_EQ(num_channels, static_cast(scales.size())) - << "Expected scales.size() to match input.size(axis) (" << num_channels - << "), but got " << scales.size(); - - ASSERT_EQ(num_channels, static_cast(zero_points.size())) - << "Expected zero_points.size() to match input.size(axis) (" - << num_channels << "), but got " << zero_points.size(); -} - -// -// Reference Implementation -// - -/* - * Reference implementation of dequantize_per_tensor - */ -at::Tensor dequantize_per_tensor_reference_impl( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, out_dtype); - - // Dequantize the input tensor - at::Tensor flat_input = input.flatten(); - at::Tensor flat_out = out.flatten(); - - // Store casted values to avoid repeated casting - const int32_t zero_point_int32 = static_cast(zero_point); - const float scale_float = static_cast(scale); - - for (int i = 0; i < flat_input.numel(); i++) { - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kChar) { - int8_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kShort) { - int16_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kInt) { - int32_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kLong) { - int64_t qvalue = flat_input[i].item(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } - - // Store result based on output dtype - if (out_dtype == at::kFloat) { - flat_out[i] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - flat_out[i] = dequantized_value; - } else if (out_dtype == at::kHalf) { - flat_out[i] = static_cast(dequantized_value); - } - } - - return out.reshape(input.sizes()); -} - -/* - * Reference implementation of dequantize_per_token - */ -at::Tensor dequantize_per_token_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, out_dtype); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of tokens matches the size of scale and zero_point - // tensors - assert(num_tokens == scale.numel()); - assert(num_tokens == zero_point.numel()); - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); - - // Dequantize each token separately - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - // Get scale and zero_point for this token - float token_scale = scale[token_idx].item(); - int64_t token_zero_point = zero_point[token_idx].item(); - - // Store casted values to avoid repeated casting - const int32_t token_zero_point_int32 = - static_cast(token_zero_point); - - // Dequantize the token - for (int i = 0; i < input.size(-1); i++) { - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kChar) { - int8_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kShort) { - int16_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kInt) { - int32_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kLong) { - int64_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else { - throw std::runtime_error("Unsupported input dtype"); - } - - // Store result based on output dtype - if (out_dtype == at::kFloat) { - reshaped_out[token_idx][i] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - reshaped_out[token_idx][i] = dequantized_value; - } else if (out_dtype == at::kHalf) { - reshaped_out[token_idx][i] = static_cast(dequantized_value); - } - } - } - - return out; -} - -/* - * Reference implementation of dequantize_per_channel - */ -at::Tensor dequantize_per_channel_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const std::optional& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Normalize axis to handle negative values - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - // Create output tensor with the same shape as input but with target dtype - at::Tensor output = at::empty_like(input, out_dtype); - - // Get the number of channels along the quantization axis - int64_t num_channels = input.size(normalized_axis); - - // Calculate strides for efficient indexing - std::vector input_strides; - std::vector input_sizes; - for (int64_t i = 0; i < input.dim(); i++) { - input_sizes.push_back(input.size(i)); - input_strides.push_back(input.stride(i)); - } - - // Get data pointers - const double* scale_data = scale.const_data_ptr(); - const int64_t* zero_point_data = nullptr; - if (zero_point.has_value()) { - zero_point_data = zero_point.value().const_data_ptr(); - } - - // Iterate through all elements in the tensor - int64_t total_elements = input.numel(); - - // Helper lambda to convert flat index to multi-dimensional coordinates - auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { - int64_t remaining = flat_idx; - for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { - coords[dim] = remaining % input_sizes[dim]; - remaining /= input_sizes[dim]; - } - }; - - // Process each element - std::vector coords(input.dim()); - for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { - // Convert flat index to coordinates - flat_to_coords(flat_idx, coords); - - // Get the channel index for this element - int64_t channel_idx = coords[normalized_axis]; - - // Get the quantization parameters for this channel - double channel_scale = scale_data[channel_idx]; - int64_t channel_zero_point = 0; - if (zero_point_data != nullptr) { - channel_zero_point = zero_point_data[channel_idx]; - } - - // Store casted values to avoid repeated casting - const int32_t channel_zero_point_int32 = - static_cast(channel_zero_point); - const float channel_scale_float = static_cast(channel_scale); - - // Get the input value and dequantize - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kChar) { - int8_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kShort) { - int16_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kInt) { - int32_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kLong) { - int64_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else { - throw std::runtime_error("Unsupported input dtype"); - } - - // Store the result based on output dtype - if (out_dtype == at::kFloat) { - output.flatten()[flat_idx] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - output.flatten()[flat_idx] = dequantized_value; - } else if (out_dtype == at::kHalf) { - output.flatten()[flat_idx] = static_cast(dequantized_value); - } - } - - return output; -} - -// Forward declaration of implementation functions -void test_vulkan_dequantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_dequantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_dequantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_channel( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_tensor_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_dequantize_per_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create a quantized input tensor with values from quant_min to quant_max - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast(quant_max - quant_min) / (input.numel() - 1); - } - - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Get reference output - at::Tensor reference_out = dequantize_per_tensor_reference_impl( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(reference_out, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "implementation:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_uint8_to_float) { - test_reference_dequantize_per_tensor( - {2, 3, 4}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int8_to_float) { - test_reference_dequantize_per_tensor( - {3, 4, 5}, // input sizes - 0.05, // scale - 0, // zero_point - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int32_to_float) { - test_reference_dequantize_per_tensor( - {4, 6, 2}, // input sizes - 0.2, // scale - 2, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_uint8_to_half) { - test_reference_dequantize_per_tensor( - {7, 4}, // input sizes - 0.1, // scale - 10, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype (uint8) - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int32_to_half) { - test_reference_dequantize_per_tensor( - {2, 6, 5}, // input sizes - 0.3, // scale - -10, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -// No Vulkan tests for quantized_decomposed.dequantize_per_tensor.default -// because it is not going to be implemented in Vulkan since we will -// be handling any future calls to this op via the export stage - -void test_reference_dequantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - // Create input tensor with quantized values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - float step = 1.0f; - if (input.size(-1) > 1) { - step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); - } - - for (int i = 0; i < input.size(-1); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } - } - } - - // Reshape back to original dimensions - input = reshaped_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = dequantize_per_token_reference_impl( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(reference_out, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "implementation:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_dequantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - // Create input tensor with quantized values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - float step = 1.0f; - if (input.size(-1) > 1) { - step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); - } - - for (int i = 0; i < input.size(-1); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } - } - } - - // Reshape back to original dimensions - input = reshaped_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_token graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(dtype), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; - if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_uint8_to_float) { - std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; - std::vector zero_points = {5, 10, 15, 20, 25, 30}; - - test_reference_dequantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int8_to_float) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {2, 2, 5}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int32_to_float) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {2, 2, 10}, // input sizes (2*2=4 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int8_to_half) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {4, 1, 5}, // input sizes (4*1=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype (int8) - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int32_to_half) { - std::vector scales = {0.05, 0.1}; - std::vector zero_points = {0, -5}; - - test_reference_dequantize_per_token( - {2, 2}, // input sizes (2 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_uint8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; - std::vector zero_points = {5, 10, 15, 20, 25, 30}; - - test_vulkan_dequantize_per_token( - {2, 3, 6}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.05, 0.0}; - std::vector zero_points = {10, -5}; - - test_vulkan_dequantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int32_to_float) { - std::vector scales = { - 0.0001, 0.0002, 0.0003, 0.0, 0.0011, 0.0102, 0.1003, 0.0}; - std::vector zero_points = {100, -100, 50, -50, 12, -6, 4, -24}; - - test_vulkan_dequantize_per_token( - {2, 2, 2, 12}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int8_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.05, 0.2}; - std::vector zero_points = {2, -5}; - - test_vulkan_dequantize_per_token( - {2, 2}, // input sizes (2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int32_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - // Use much smaller scales to avoid overflow to infinity in half precision - // Half precision max value is ~65504, so with int32 values around 2e9, - // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow - std::vector scales = {1e-5, 2e-5, 1.5e-5}; - std::vector zero_points = {20, -15, 1}; - - test_vulkan_dequantize_per_token( - {3, 6}, // input sizes (3 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int8_to_double) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.05, 0.001}; - std::vector zero_points = {10, -5}; - - test_vulkan_dequantize_per_token( - {2, 2}, // input sizes (2 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kDouble); // output dtype -} - -void test_reference_dequantize_per_channel( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); - - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create input tensor with quantized values - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast(quant_max - quant_min) / (input.numel() - 1); - } - - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor my_ref = dequantize_per_channel_reference_impl( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Get implementation output - at::Tensor cpu_ref = torch::executor::native::dequantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(my_ref, cpu_ref); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "cpu_ref:" << std::endl; - std::cout << cpu_ref << std::endl; - std::cout << "my_ref:" << std::endl; - std::cout << my_ref << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_dequantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); - - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create random float tensor - at::Tensor float_x = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); - - // Map the dtype to the corresponding quantized type and quantize the float - // tensor - c10::ScalarType qtype; - at::Tensor adjusted_zero_points = zero_point_tensor; - - if (dtype == at::kByte) { - qtype = c10::kQUInt8; - // ATEN ONLY: Adjust zero points for unsigned types (must be non-negative) - adjusted_zero_points = at::clamp_min(zero_point_tensor, 0); - } else if (dtype == at::kChar) { - qtype = c10::kQInt8; - } else if (dtype == at::kInt) { - qtype = c10::kQInt32; - } else { - std::cout << "invalid dtype for ATEN: " << dtype << std::endl; - std::cout << " --> Delegating to c10::kQInt32" << std::endl; - qtype = c10::kQInt32; - } - - // Normalize axis for ATen (ATen doesn't handle negative axes in - // quantize_per_channel) - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes_int64.size(); - } - - // Quantize using ATen - at::Tensor quantized_aten = at::quantize_per_channel( - float_x, scale_tensor, adjusted_zero_points, normalized_axis, qtype); - - // Get ATen dequantized output - at::Tensor aten_out = at::dequantize(quantized_aten).to(out_dtype); - - // Extract the quantized values (int_repr) to use with our implementations - at::Tensor quantized_input = quantized_aten.int_repr().to(dtype); - - // Get reference output using - // torch::executor::native::dequantize_per_channel_aten - at::Tensor reference_out = - torch::executor::native::dequantize_per_channel_aten( - quantized_input, - scale_tensor.to(at::kDouble), - zero_point_tensor.to(at::kLong), - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_channel graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - // Add tensors to graph - IOValueRef r_input = graph.add_input_tensor( - quantized_input.sizes().vec(), - from_at_scalartype(quantized_input.scalar_type()), - in_storage); - - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - - IOValueRef r_zero_point = graph.add_input_tensor( - adjusted_zero_points.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - ValueRef r_out = graph.add_tensor( - quantized_input.sizes().vec(), - from_at_scalartype(out_dtype), - out_storage); - - const ValueRef r_axis = graph.add_scalar(axis); - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - const ValueRef r_output_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_channel.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_axis, - r_quant_min, - r_quant_max, - r_dtype, - r_output_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, - quantized_input.const_data_ptr(), - quantized_input.numel()); - - // copy scale tensor to GPU - graph.copy_into_staging( - r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); - - // copy zero_point tensor to GPU - graph.copy_into_staging( - r_zero_point.staging, - zero_point_tensor.const_data_ptr(), - zero_point_tensor.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; - if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - std::cout << " storage: " << in_storage << std::endl; - std::cout << std::endl; - - std::cout << "\033[91m quantized_input: \033[0m" << std::endl; - std::cout << quantized_input << std::endl; - std::cout << "\033[91m aten: \033[0m" << std::endl; - std::cout << aten_out << std::endl; - std::cout << "\033[91m reference: \033[0m" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "\033[91m vulkan: \033[0m" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { - std::vector scales = {0.1, 0.2, 0.3}; - std::vector zero_points = {0, 5, -2}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int8_to_float_3D_axis2) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int8_to_float_3D_axisn1) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int32_to_float_4D_axis0) { - std::vector scales = {0.1, 0.2, 0.00002}; - std::vector zero_points = {0, 5, -4}; - - test_reference_dequantize_per_channel( - {3, 4, 2, 5}, // input sizes - scales, - zero_points, - 0, // axis - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, - at::kFloat); -} - -// END OF REFERENCE TESTS - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis0) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(9, 0.1f); - std::vector zero_points(9, 2); - - // 1D Tensor - test_vulkan_dequantize_per_channel( - {9}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 2D Tensor - test_vulkan_dequantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 7, 11}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 17, 5, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 17, 5, 9}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis1) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(14, 0.001f); - std::vector zero_points(14, -5); - - // 2D Tensor - test_vulkan_dequantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 5}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {9, 7, 14, 5}, // input sizes - scales, - zero_points, - -2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis2) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(11, 0.5f); - std::vector zero_points(11, 12); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {9, 11, 14, 5}, // input sizes - scales, - zero_points, - -3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis3) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(7, 0.5f); - std::vector zero_points(7, 12); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 7}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {7, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; - std::vector zero_points = {0, 5, -5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_8bit_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kHalf); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_8bit_to_double) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kDouble); -} - -void test_vulkan_dequantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create a quantized input tensor with values from quant_min to quant_max - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast(quant_max - quant_min) / (input.numel() - 1); - } - - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Create scale and zero_point as tensors (single element tensors) - at::Tensor scale_tensor = - at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output using tensor variant - at::Tensor reference_out = - torch::executor::native::dequantize_per_tensor_tensor_args_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_tensor.tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(dtype), in_storage); - - // Add scale and zero_point as tensor inputs (buffer storage, width packed) - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - const ValueRef r_out_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.tensor") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_out_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Run Vulkan dequantize_per_tensor.tensor - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - graph.execute(); - - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; - if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_int8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.01, // scale - 1, // zero_point - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_uint8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3, 4, 12}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_int32_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3}, // input sizes - 0.01, // scale - 12, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_uint8_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {3, 4}, // input sizes - 0.3, // scale - 2, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_int8_to_double) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.03, // scale - -2, // zero_point - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kDouble); // output dtype -} diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp deleted file mode 100644 index 86eebcf9b14..00000000000 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ /dev/null @@ -1,2188 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include -#include - -float eps = 1e-7; - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -Tensor& quantize_per_tensor_out( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -Tensor& quantize_per_token_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -Tensor& quantize_per_channel_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -Tensor& quantize_per_tensor_tensor_args_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out); - -// Wrapper function for quantize_per_tensor_out without context -Tensor& quantize_per_tensor_out_no_context( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_tensor_out( - input, scale, zero_point, quant_min, quant_max, dtype, out); -} - -// Wrapper function for quantize_per_token_out without context -Tensor& quantize_per_token_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_token_out( - input, scale, zero_point, quant_min, quant_max, dtype, out); -} - -// Wrapper function for quantize_per_channel_out without context -Tensor& quantize_per_channel_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_channel_out( - input, scale, zero_point, axis, quant_min, quant_max, dtype, out); -} - -// Wrapper function for quantize_per_tensor_tensor_args_out without context -Tensor& quantize_per_tensor_tensor_args_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) { - return torch::executor::native::quantize_per_tensor_tensor_args_out( - input, scale, zero_point, quant_min, quant_max, dtype, out); -} - -// ATen wrapper for quantize_per_tensor -at::Tensor quantize_per_tensor_aten( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6) - (input, scale, zero_point, quant_min, quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_token -at::Tensor quantize_per_token_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_token_out_no_context, 6) - (input, scale, zero_point, quant_min, quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_channel -at::Tensor quantize_per_channel_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_channel_out_no_context, 7) - (input, scale, zero_point, axis, quant_min, quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_tensor with tensor args -at::Tensor quantize_per_tensor_tensor_args_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_tensor_tensor_args_out_no_context, 6) - (input, scale, zero_point, quant_min, quant_max, et_dtype, out); - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch - -void check_quantize_args( - int64_t quant_min, - int64_t quant_max, - c10::ScalarType out_dtype) { - using namespace vkcompute; - int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; - switch (out_dtype) { - case c10::kByte: - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - break; - case c10::kChar: - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - break; - case c10::kBits16: - case c10::kUInt16: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - case c10::kShort: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - case c10::kInt: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - default: - VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype)); - } - VK_CHECK_COND( - quant_min >= quant_min_lower_bound, - "quant_min out of bound for dtype, expected quant_min_lower_bound: ", - quant_min_lower_bound, - " actual quant_min: ", - quant_min); - - VK_CHECK_COND( - quant_max <= quant_max_upper_bound, - "quant_max out of bound for dtype, expected quant_max_upper_bound: ", - quant_max_upper_bound, - " actual quant_max: ", - quant_max); -} - -/** - * Helper function to validate quantize_per_channel arguments - * Similar to the validation in op_quantize.cpp - */ -void check_quantize_per_channel_args( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis) { - // Normalize axis - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes.size(); - } - - ASSERT_GE(normalized_axis, 0) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be >= 0"; - - ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be < input.dim() " << input_sizes.size(); - - int64_t num_channels = input_sizes[normalized_axis]; - - ASSERT_EQ(num_channels, static_cast(scales.size())) - << "Expected scales.size() to match input.size(axis) (" << num_channels - << "), but got " << scales.size(); - - ASSERT_EQ(num_channels, static_cast(zero_points.size())) - << "Expected zero_points.size() to match input.size(axis) (" - << num_channels << "), but got " << zero_points.size(); -} - -// -// Reference Implementation -// - -/* - * Reference implementation of quantize_per_tensor - */ -at::Tensor quantize_per_tensor_reference_impl( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, dtype); - - // Quantize the input tensor - float inv_scale = 1.0 / scale; - - // Iterate through the tensor and quantize each element - at::Tensor float_input = input.to(at::kFloat); - at::Tensor float_values = float_input.flatten(); - - auto out_flat = out.flatten(); - - for (int i = 0; i < float_values.numel(); i++) { - float value = float_values[i].item(); - int64_t qvalue = zero_point + std::nearbyint(inv_scale * value); - - qvalue = std::max(qvalue, quant_min); - qvalue = std::min(qvalue, quant_max); - - if (dtype == at::kByte) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - out_flat[i] = static_cast(qvalue); - } - } - - return out.reshape(input.sizes()); -} - -/* - * Reference implementation of quantize_per_token - */ -at::Tensor quantize_per_token_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, dtype); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of tokens matches the size of scale and zero_point - // tensors - assert(num_tokens == scale.numel()); - assert(num_tokens == zero_point.numel()); - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); - - // Quantize each token separately - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - // Use float for scale since Vulkan doesn't support double - float token_scale = scale[token_idx].item(); - // Use int for zero_point since Vulkan doesn't support int64_t - int token_zero_point = zero_point[token_idx].item(); - - float inv_scale = 1.0 / token_scale; - - // Quantize the token - for (int i = 0; i < input.size(-1); i++) { - float value = reshaped_input[token_idx][i].item(); - int qvalue = token_zero_point + std::nearbyint(inv_scale * value); - - qvalue = std::max(qvalue, quant_min); - qvalue = std::min(qvalue, quant_max); - - if (dtype == at::kByte) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } - } - } - - return out; -} - -/* - * Reference implementation of quantize_per_channel - */ -at::Tensor quantize_per_channel_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Normalize axis to handle negative values - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - // Create output tensor with the same shape as input but with target dtype - at::Tensor output = at::empty_like(input, dtype); - - // Get the number of channels along the quantization axis - int64_t num_channels = input.size(normalized_axis); - - // Calculate strides for efficient indexing - std::vector input_strides; - std::vector input_sizes; - for (int64_t i = 0; i < input.dim(); i++) { - input_sizes.push_back(input.size(i)); - input_strides.push_back(input.stride(i)); - } - - // Get data pointers - const float* input_data = input.const_data_ptr(); - const double* scale_data = scale.const_data_ptr(); - const int64_t* zero_point_data = zero_point.const_data_ptr(); - - // Iterate through all elements in the tensor - int64_t total_elements = input.numel(); - - // Helper lambda to convert flat index to multi-dimensional coordinates - auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { - int64_t remaining = flat_idx; - for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { - coords[dim] = remaining % input_sizes[dim]; - remaining /= input_sizes[dim]; - } - }; - - // Process each element - std::vector coords(input.dim()); - for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { - // Convert flat index to coordinates - flat_to_coords(flat_idx, coords); - - // Get the channel index for this element - int64_t channel_idx = coords[normalized_axis]; - - // Get the quantization parameters for this channel - double channel_scale = scale_data[channel_idx]; - int64_t channel_zero_point = zero_point_data[channel_idx]; - - // Get the input value - float input_value = input_data[flat_idx]; - - // Apply quantization formula: round(input / scale) + zero_point - float inv_scale = 1.0f / static_cast(channel_scale); - int64_t quantized_value = static_cast( - static_cast(channel_zero_point) + - std::nearbyint(static_cast(inv_scale * input_value))); - - // Clamp to quantization bounds - quantized_value = std::max(quantized_value, quant_min); - quantized_value = std::min(quantized_value, quant_max); - - // Store the result based on output dtype - switch (dtype) { - case at::kByte: { - uint8_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kChar: { - int8_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kShort: { - int16_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kInt: { - int32_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - default: - assert(false && "Unsupported output dtype"); - } - } - - return output; -} - -// Forward declaration of implementation functions -void test_vulkan_quantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_quantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_quantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_quantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_channel( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - test_vulkan_quantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_tensor_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_quantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_quantize_per_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0f / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - scale = scale < eps ? eps : scale; - - // Get reference output - at::Tensor reference_out = quantize_per_tensor_reference_impl( - input, scale, zero_point, quant_min, quant_max, dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten( - input, scale, zero_point, quant_min, quant_max, dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor impl_int = impl_out.to(at::kInt); - - const bool output_correct = at::equal(reference_int, impl_int); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - impl_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "my_reference:" << std::endl; - std::cout << impl_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_float_to_int8) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.1, // scale - 0, // zero_point - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_float_to_int32) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.04, // scale - 5, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_half_to_uint8) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.2, // scale - 2, // zero_point - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_half_to_int32) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.01, // scale - 1, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kHalf, - at::kInt); -} - -// No Vulkan tests for quantized_decomposed.quantize_per_tensor.default -// because it is not going to be implemented in Vulkan since we will -// be handling any future calls to this op via the export stage - -void test_reference_quantize_per_token( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0 / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of tokens matches the size of scales and zero_points - ASSERT_EQ(num_tokens, pre_scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? eps : s; - } - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = quantize_per_token_reference_impl( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::quantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor impl_int = impl_out.to(at::kInt); - - const bool output_correct = at::equal(reference_int, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "my_reference:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_quantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, pre_scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? eps : s; - } - - // Create input tensor with random values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output to show what we would compare against - at::Tensor reference_out = torch::executor::native::quantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_float_to_int8) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_float_to_int32) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_half_to_int32) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kHalf, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_half_to_uint8) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5*2=10 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); -} - -TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_int32) { - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5*2=10 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_int32_small_scales) { - std::vector scales = { - 0, - 2.9387358770557188e-39f, - 1.40129846e-45f, - 1.17549435e-38f, - 0.0000000000001}; - std::vector zero_points = {20, -10, 15, 200, 50}; - - test_vulkan_quantize_per_token( - {5, 2}, // input sizes (3 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(18, 0.1); - std::vector zero_points(18, 5); - - // Alternate scale values - for (size_t i = 0; i < scales.size(); i++) { - scales[i] = (i % 2 == 0) ? 0.3 : -0.5; - } - - test_vulkan_quantize_per_token( - {3, 3, 2, 3}, // input sizes (3*3*2=18 tokens) - scales, - zero_points, - 0, // quant_min - 125, // quant_max - at::kFloat, - at::kByte); -} - -TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_vulkan_quantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kHalf, // input dtype - at::kChar); // output dtype -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_double_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_vulkan_quantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kDouble, // input dtype - at::kChar); // output dtype -} - -void test_reference_quantize_per_channel( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); - - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0f / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? eps : s; - } - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor my_ref = quantize_per_channel_reference_impl( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - // Get implementation output - at::Tensor cpu_ref = torch::executor::native::quantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - // Get direct ATen implementation output - c10::ScalarType aten_dtype = dtype; - if (dtype == at::kChar) { - aten_dtype = c10::kQInt8; - } else if (dtype == at::kByte) { - aten_dtype = c10::kQUInt8; - } - - // Normalize axis for ATen (it doesn't handle negative values) - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - at::Tensor aten_ref = at::quantize_per_channel( - input, scale_tensor, zero_point_tensor, normalized_axis, aten_dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor my_ref_int = my_ref.to(at::kInt); - at::Tensor cpu_ref_int = cpu_ref.to(at::kInt); - // For quantized tensors, we need to use int_repr() to get the underlying - // integer values - at::Tensor aten_ref_int = aten_ref.int_repr().to(at::kInt); - - const bool output_correct = at::equal(my_ref_int, cpu_ref_int); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "aten_ref:" << std::endl; - std::cout << aten_ref_int << std::endl; - std::cout << "cpu_ref:" << std::endl; - std::cout << cpu_ref_int << std::endl; - std::cout << "my_ref:" << std::endl; - std::cout << my_ref_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_quantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? eps : s; - } - - // Create input tensor with random values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = torch::executor::native::quantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_axis = graph.add_scalar(axis); - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_channel.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_axis, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axis0) { - std::vector scales = {0.1, 0.2, 0.3}; - std::vector zero_points = {0, 5, -2}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axis2) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axisn1) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_4D_axis0) { - std::vector scales = {0.1, 0.2, 0.00002}; - std::vector zero_points = {0, 5, -4}; - - test_reference_quantize_per_channel( - {3, 4, 2, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -// END OF REFERENCE TESTS - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis0) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(9, 0.1f); - std::vector zero_points(9, 2); - - // 1D Tensor - test_vulkan_quantize_per_channel( - {9}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 2D Tensor - test_vulkan_quantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 3D Tensor - test_vulkan_quantize_per_channel( - {9, 7, 11}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 17, 5, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 17, 5, 9}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis1) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(14, 0.001f); - std::vector zero_points(14, -5); - - // 2D Tensor - test_vulkan_quantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 3D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 5}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {9, 7, 14, 5}, // input sizes - scales, - zero_points, - -2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis2) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(11, 0.5f); - std::vector zero_points(11, 12); - - // 3D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {9, 11, 14, 5}, // input sizes - scales, - zero_points, - -3, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis3) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(7, 0.5f); - std::vector zero_points(7, 12); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 7}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {7, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; - std::vector zero_points = {0, 5, -5, 1, 12}; - - // 4D Tensor - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_half_to_8bit) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kHalf, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kHalf, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kHalf, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_double_to_8bit) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kDouble, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kDouble, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kDouble, - at::kByte); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kDouble, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kDouble, - at::kByte); -} - -void test_vulkan_quantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - scale = scale < eps ? eps : scale; - - // Create scale and zero_point as tensors (single element tensors) - at::Tensor scale_tensor = - at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output using tensor variant - at::Tensor reference_out = - torch::executor::native::quantize_per_tensor_tensor_args_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Build Vulkan quantize_per_tensor.tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - // Add scale and zero_point as tensor inputs (buffer storage, width packed) - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.tensor") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Run Vulkan quantize_per_tensor.tensor - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - graph.execute(); - - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - // For quantized types, we need to compare the actual integer values - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.01, // scale - 1, // zero_point - -128, // quant_min - 127, // quant_max - at::kFloat, // input dtype - at::kChar); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4, 12}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kFloat, // input dtype - at::kByte); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_int32) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3}, // input sizes - 0.01, // scale - 12, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, // input dtype - at::kInt); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_half_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {3, 4}, // input sizes - 0.3, // scale - 2, // zero_point - 0, // quant_min - 255, // quant_max - at::kHalf, // input dtype - at::kByte); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_double_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.03, // scale - -2, // zero_point - -128, // quant_min - 127, // quant_max - at::kDouble, // input dtype - at::kChar); // output dtype -} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index b9386f92772..dae2eddf8b2 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -177,33 +177,6 @@ def define_common_targets(is_fbcode = False): "//executorch/extension/tensor:tensor", ] ) - define_test_targets( - "quantize_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_quantize", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) - define_test_targets( - "dequantize_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_dequantize", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) - define_test_targets( - "choose_qparams_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_choose_qparams", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) define_test_targets( "quantized_linear_test", extra_deps = [ diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 26371bc41ff..49419a50399 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -363,7 +363,7 @@ def generate_suite_cpp(self) -> str: static_cast(indices[0].size())}}; // Flatten indices as from_blob reads garbage otherwise. - std::vector acc; + std::vector acc; for (auto& vec: indices) {{ acc.insert(acc.end(), vec.begin(), vec.end()); }} @@ -380,7 +380,7 @@ def generate_suite_cpp(self) -> str: static_cast(indices[0][0].size())}}; // Flatten indices as from_blob reads garbage otherwise. - std::vector acc; + std::vector acc; for (auto& v: indices) {{ for (auto& vv: v) {{ acc.insert(acc.end(), vv.begin(), vv.end()); diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index f92cea64767..03a3263c293 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -11,20 +11,14 @@ from typing import Tuple import executorch.backends.vulkan.test.utils as test_utils - import torch - from executorch.backends.transforms.convert_dtype_pass import I64toI32 - from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner - from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend - from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) - from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, @@ -36,11 +30,8 @@ ) from executorch.extension.pytree import tree_flatten from torch.export import Dim, export, ExportedProgram - from torchao.quantization.granularity import PerGroup - from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - from torchao.quantization.pt2e.quantizer import Quantizer from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ from torchao.utils import unwrap_tensor_subclass @@ -69,9 +60,6 @@ def lower_module( edge_program = to_edge_transform_and_lower( program, compile_config=edge_compile_config, - transform_passes=[ - I64toI32(edge_compile_config._skip_dim_order), - ], partitioner=[VulkanPartitioner(compile_options)], ) @@ -1969,102 +1957,6 @@ def forward(self, x): sample_inputs, ) - def test_vulkan_backend_full_quantization_workflow(self): - class FullQuantizationWorkflowModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - # Step 1: Choose quantization parameters per tensor - scale, zero_point = ( - torch.ops.quantized_decomposed.choose_qparams.tensor( - x, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - eps=1e-5, - dtype=torch.int32, - ) - ) - - # Step 2: Quantize using the calculated parameters - quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor( - x, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - ) - - # Step 3: Dequantize back to float - dequantized = ( - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor( - quantized, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - ) - ) - - return dequantized - - full_workflow_module = FullQuantizationWorkflowModule() - sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),) - - # Use higher tolerance since quantization introduces some error - self.lower_module_and_test_output( - full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3 - ) - - def test_vulkan_backend_full_per_token_quantization_workflow(self): - class FullPerTokenQuantizationWorkflowModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - # Step 1: Choose quantization parameters per token - scale, zero_point = ( - torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( - x, - dtype=torch.int32, - ) - ) - - # Step 2: Quantize using the calculated parameters per token - quantized = torch.ops.quantized_decomposed.quantize_per_token.default( - x, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - ) - - # Step 3: Dequantize back to float per token - dequantized = ( - torch.ops.quantized_decomposed.dequantize_per_token.default( - quantized, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - output_dtype=torch.float32, - ) - ) - - return dequantized - - full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule() - sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),) - - # Use higher tolerance since quantization introduces some error - self.lower_module_and_test_output( - full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3 - ) - def test_vulkan_backend_different_required_reprs(self): class ComplexModule(torch.nn.Module): """ diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index 90edc094ec7..6d3fff452f8 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -8,18 +8,14 @@ import logging from collections import OrderedDict from copy import deepcopy - from enum import auto, Enum from typing import Any, List, Optional, Tuple import executorch.backends.vulkan.utils as utils - import torch - from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner - from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, @@ -36,7 +32,6 @@ ) from executorch.extension.pytree import tree_flatten from torch.export import export - from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -275,16 +270,25 @@ def check_outputs_equal( ) return result else: + result = True for i in range(len(ref_output)): - if not torch.allclose( - model_output[i], ref_output[i], atol=atol, rtol=rtol - ): - print(f"\n=== Output {i} comparison failed ===") - print_tensor_comparison_errors( - model_output[i], ref_output[i], atol, rtol - ) - return False - return True + if isinstance(ref_output[i], torch.Tensor): + if not torch.allclose( + model_output[i], ref_output[i], atol=atol, rtol=rtol + ): + print(f"\n=== Output {i} comparison failed ===") + print_tensor_comparison_errors( + model_output[i], ref_output[i], atol, rtol + ) + result = False + elif isinstance(ref_output[i], int): + if not model_output[i] == ref_output[i]: + print(f"\n=== Output {i} comparison failed ===") + print(f"{model_output[i]} vs {ref_output[[i]]}") + result = False + else: + print(f"WARNING: Output {i} has type {type(ref_output[i])}") + return result else: # If one output, eager returns tensor while executor tuple of size 1 result = torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) @@ -326,7 +330,7 @@ def run_and_check_output( model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) # Generate reference outputs using the reference model - ref_output = reference_model(*sample_inputs) + ref_output, _ = tree_flatten(reference_model(*sample_inputs)) # Check if outputs are equal return check_outputs_equal( @@ -805,3 +809,26 @@ def find_bad_operators( "all_operators": all_operators, "test_count": test_count, } + + +def make_indent(indent_level): + indent_str = "" + for _ in range(indent_level): + indent_str += " " + return indent_str + + +def print_output(outputs, n: int = 0, indent_level: int = 0): + if isinstance(outputs, (list, tuple)): + print(f"{make_indent(indent_level)}output_{n} = {type(outputs)}") + new_indent_level = indent_level + 2 + for n, test_out in enumerate(outputs): + print_output(test_out, n, new_indent_level) + elif isinstance(outputs, torch.Tensor): + print( + f"{make_indent(indent_level)}output_{n} = test_utils.random_uniform_tensor({outputs.shape}, low={outputs.min().item()}, high={outputs.max().item()}, dtype={outputs.dtype})" + ) + elif isinstance(outputs, int): + print(f"{make_indent(indent_level)}output_{n} = {outputs}") + else: + print(f"{make_indent(indent_level)}output_{n} = {type(outputs)}") diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 7e3d957afdb..7dd3bb84588 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1911,413 +1911,6 @@ TEST(VulkanComputeGraphTest, test_clone) { } } -TEST(VulkanComputeGraphTest, test_etvk_copy_offset_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 6; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, memory_layout); - - // Notice that copy_node operates on in texture's x, y, z dimension. In the - // comment, we provide the cooresponding coordinate in nchw. - - // src_offset is (n=0, c=4, h=1, w=1) - ValueRef src_offset_ref = graph.add_scalar_list({1, 1, 1}); - - // dst_offset is (n=1, c=8, h=2, w=0) in nchw coordinate - // Argument is {x, y, z}. - // x = 0 since w = 0 - // y = 2 since h = 2 - // z = c / 4 + 2 since - // 1. there c/4 planes per batch, n=1 means we are on the first batch; - // 2. +2 because c = 8, with channel packing it means two texels. - ValueRef dst_offset_ref = graph.add_scalar_list({0, 2, c / 4 + 2}); - - // range is (n=1, c=8, h=2, w=4) - // Argument is {x, y, z}. - // x = 4 since w = 4 - // y = 2 since h = 2 - // z = 2 since we are only copying 8 channels, hence 2 texel. n = 1 can be a - // bit misleading here, since it gives the impression that we are copying the - // entire channel. However, remember when we copy, we are trying to - // dst[dst_offset:dst_offset + range] = src[src_offset:src_offset + range], - // range must be non zero. - ValueRef range_ref = graph.add_scalar_list({4, 2, 2}); - - auto copyFn = VK_GET_OP_FN("etvk.copy_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0.0f, /*iota = */ true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - // We will examine the results in the dst_range - // The value in the cooresponding coordinate should match between the source - // and destination tensor. We loop thru the range, calculate both the src and - // dst index using the offsets, and compare the values in the extracted - // vector. They should match. - int n_idx = 0; - // at each nested loop, index range from dst_offset to dst_offset + range - - for (int c_idx = 0; c_idx < 8; c_idx++) { - for (int h_idx = 0; h_idx < 2; h_idx++) { - for (int w_idx = 0; w_idx < 4; w_idx++) { - auto dst_idx = - get_buf_idx(graph, out, {n_idx + 1, c_idx + 8, h_idx + 2, w_idx}); - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + 4, h_idx + 1, w_idx + 1}); - - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } -} - -TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 2; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, memory_layout); - - int64_t src_offset = 2; - int64_t dst_offset = 3; - int64_t range = 7; - - ValueRef src_offset_ref = graph.add_scalar(src_offset); - ValueRef dst_offset_ref = graph.add_scalar(dst_offset); - ValueRef range_ref = graph.add_scalar(range); - - auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0.0f, true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = 0; c_idx < range; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + src_offset, h_idx, w_idx}); - auto dst_idx = get_buf_idx( - graph, out, {n_idx, c_idx + dst_offset, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } - } -} - -TEST( - VulkanComputeGraphTest, - DISABLED_test_etvk_copy_channel_offset_node_clean_boundary) { - // Tricky part for channel copy is handling the boundary across multiple copy. - // For example, when we concat two [3, 1, 1] nchw-tensors along the channel - // dimension, due to channel packing, elements from different source texel - // will be packed into same destination texel at the boundaries. - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 2; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef zero = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - IOValueRef b = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, memory_layout); - - auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); - - // Make sure entire out tensor is zeroed. The zero tensor will be filled with - // zero later. - copyFn( - graph, - {zero.value, - graph.add_scalar(c), - graph.add_scalar(0), - graph.add_scalar(0), - out.value}); - - int64_t a_src_offset = 0; - int64_t a_dst_offset = 2; - int64_t a_range = 5; - // a will write to channge [2, 7) - copyFn( - graph, - {a.value, - graph.add_scalar(a_range), - graph.add_scalar(a_src_offset), - graph.add_scalar(a_dst_offset), - out.value}); - - // b will write to channel [6, 11) - // Intentional for b to override channel=6 - int64_t b_src_offset = 0; - int64_t b_dst_offset = 6; - int64_t b_range = 5; - - copyFn( - graph, - {b.value, - graph.add_scalar(b_range), - graph.add_scalar(b_src_offset), - graph.add_scalar(b_dst_offset), - out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - float a_value = 1.0f; - float b_value = 2.0f; - float zero_value = 0.0f; - fill_vtensor(graph, a, a_value); - fill_vtensor(graph, b, b_value); - fill_vtensor(graph, zero, zero_value); - - graph.execute(); - - EXTRACT_TENSOR(out); - - for (int n_idx = 0; n_idx < n; n_idx++) { - // c_idx only up to a_range-1 because the expected overwrite by b - for (int c_idx = a_dst_offset; c_idx < a_dst_offset + a_range - 1; - c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == a_value); - } - } - } - } - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = b_dst_offset; c_idx < b_dst_offset + b_range; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == b_value); - } - } - } - } - - // Also verify that data before a_dst_offset and after b_dst_offset + b_range - // are untouched. - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = 0; c_idx < a_dst_offset; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == zero_value); - } - } - } - } - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = b_dst_offset + b_range; c_idx < c; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto dst_idx = get_buf_idx(graph, out, {n_idx, c_idx, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == zero_value); - } - } - } - } -} - -TEST(VulkanComputeGraphTest, test_etvk_copy_offset_int_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 6; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kInt, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kInt, memory_layout); - - // Notice that copy_node operates on in texture's x, y, z dimension. In the - // comment, we provide the cooresponding coordinate in nchw. - - // src_offset is (n=0, c=4, h=1, w=1) - ValueRef src_offset_ref = graph.add_scalar_list({1, 1, 1}); - - // dst_offset is (n=1, c=8, h=2, w=0) in nchw coordinate - // Argument is {x, y, z}. - // x = 0 since w = 0 - // y = 2 since h = 2 - // z = c / 4 + 2 since - // 1. there c/4 planes per batch, n=1 means we are on the first batch; - // 2. +2 because c = 8, with channel packing it means two texels. - ValueRef dst_offset_ref = graph.add_scalar_list({0, 2, c / 4 + 2}); - - // range is (n=1, c=8, h=2, w=4) - // Argument is {x, y, z}. - // x = 4 since w = 4 - // y = 2 since h = 2 - // z = 2 since we are only copying 8 channels, hence 2 texel. n = 1 can be a - // bit misleading here, since it gives the impression that we are copying the - // entire channel. However, remember when we copy, we are trying to - // dst[dst_offset:dst_offset + range] = src[src_offset:src_offset + range], - // range must be non zero. - ValueRef range_ref = graph.add_scalar_list({4, 2, 2}); - - auto copyFn = VK_GET_OP_FN("etvk.copy_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0, /*iota = */ true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - // We will examine the results in the dst_range - // The value in the cooresponding coordinate should match between the source - // and destination tensor. We loop thru the range, calculate both the src and - // dst index using the offsets, and compare the values in the extracted - // vector. They should match. - int n_idx = 0; - // at each nested loop, index range from dst_offset to dst_offset + range - - for (int c_idx = 0; c_idx < 8; c_idx++) { - for (int h_idx = 0; h_idx < 2; h_idx++) { - for (int w_idx = 0; w_idx < 4; w_idx++) { - auto dst_idx = - get_buf_idx(graph, out, {n_idx + 1, c_idx + 8, h_idx + 2, w_idx}); - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + 4, h_idx + 1, w_idx + 1}); - - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } -} - -TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_int_node) { - GraphConfig config; - ComputeGraph graph(config); - - int64_t n = 2; - int64_t c = 12; - int64_t h = 4; - int64_t w = 8; - utils::GPUMemoryLayout memory_layout = - utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED; - - std::vector size = {n, c, h, w}; - - IOValueRef a = graph.add_input_tensor(size, vkapi::kFloat, memory_layout); - - IOValueRef out = {}; - out.value = graph.add_tensor(size, vkapi::kFloat, memory_layout); - - int64_t src_offset = 2; - int64_t dst_offset = 3; - int64_t range = 7; - - ValueRef src_offset_ref = graph.add_scalar(src_offset); - ValueRef dst_offset_ref = graph.add_scalar(dst_offset); - ValueRef range_ref = graph.add_scalar(range); - - auto copyFn = VK_GET_OP_FN("etvk.copy_channel_offset"); - copyFn( - graph, {a.value, range_ref, src_offset_ref, dst_offset_ref, out.value}); - - out.staging = graph.set_output_tensor(out.value); - - graph.prepare(); - graph.prepack(); - - fill_vtensor(graph, a, 0.0f, true); - - graph.execute(); - - EXTRACT_TENSOR(out); - EXTRACT_TENSOR(a); - - for (int n_idx = 0; n_idx < n; n_idx++) { - for (int c_idx = 0; c_idx < range; c_idx++) { - for (int h_idx = 0; h_idx < h; h_idx++) { - for (int w_idx = 0; w_idx < w; w_idx++) { - auto src_idx = - get_buf_idx(graph, a, {n_idx, c_idx + src_offset, h_idx, w_idx}); - auto dst_idx = get_buf_idx( - graph, out, {n_idx, c_idx + dst_offset, h_idx, w_idx}); - EXPECT_TRUE(data_out[dst_idx] == data_a[src_idx]); - } - } - } - } -} - TEST(VulkanComputeGraphTest, test_view_change_packing) { std::vector> layout_pairs = { diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 00147dab2c3..2ca2ddf19b7 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -8,26 +8,18 @@ from typing import Any, List, Optional, Set, Tuple, Union import torch - from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, ) - from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) - from executorch.exir.dialects.edge._ops import EdgeOpOverload - from executorch.exir.tensor import TensorSpec - from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param - from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter - from torch.export import ExportedProgram - from torch.export.exported_program import InputKind from torch.export.graph_signature import TensorArgument @@ -373,6 +365,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]: return None +def node_has_target(node: Any, target: str): + if not hasattr(node, "target"): + return False + + if isinstance(node.target, str): + return node.target == target + elif hasattr(node.target, "name"): + return node.target.name() == target + + return False + + ## ## Memory Layout, Storage Type Determination ## @@ -387,10 +391,23 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]: VkStorageType.TEXTURE_3D, } +# Memory layouts available to non-quantized tensors all_memory_layouts: Set[VkMemoryLayout] = { VkMemoryLayout.TENSOR_WIDTH_PACKED, VkMemoryLayout.TENSOR_HEIGHT_PACKED, VkMemoryLayout.TENSOR_CHANNELS_PACKED, +} + +# Memory layouts available to quantized tensors +all_quantized_memory_layouts: Set[VkMemoryLayout] = { + VkMemoryLayout.PACKED_INT8_4W4C, + VkMemoryLayout.PACKED_INT8_4H4W, +} + +universal_memory_layout_set: Set[VkMemoryLayout] = { + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_HEIGHT_PACKED, + VkMemoryLayout.TENSOR_CHANNELS_PACKED, VkMemoryLayout.PACKED_INT8_4W4C, VkMemoryLayout.PACKED_INT8_4H4W, } @@ -749,7 +766,7 @@ def make_filtered_tensor_repset( ## Convenience TensorRepSet definitions -PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set()) +# Only includes memory layouts that can be used by non-quantized tensors CONTIGUOUS_ANY = TensorRepSet( {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED} @@ -760,11 +777,28 @@ def make_filtered_tensor_repset( HEIGHT_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_HEIGHT_PACKED}) CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) +CHANNELS_PACKED_ANY = TensorRepSet( + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + +CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts) ANY_BUFFER = TensorRepSet(all_memory_layouts, set()) - ANY_STORAGE = TensorRepSet(all_memory_layouts, all_memory_layouts) + +# Only includes memory layouts that can be used by quantized tensors + +PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set()) + +# Special use RepSets + NO_STORAGE = TensorRepSet(set(), set()) +ALL_STORAGES_REPSET = TensorRepSet( + universal_memory_layout_set, universal_memory_layout_set +) class TensorRepSetList: @@ -888,19 +922,19 @@ def __init__( # noqa: C901 # Now, go through the arguments of the operator and create a filtered repset # for each based on the actual tensor value. args_repset_list = TensorRepSetList([]) - common_arg_repset = ANY_STORAGE + common_arg_repset = ALL_STORAGES_REPSET for i, arg_node in enumerate(op_node.args): arg_repset = inputs_repsets[i] - # Use ANY_STORAGE for non-tensor nodes so they don't cause the op repsets to - # appear empty + # Use ALL_STORAGES_REPSET for non-tensor nodes so they don't cause the op + # repsets to appear empty if not is_tensor_arg_node(arg_node): - args_repset_list.append(ANY_STORAGE) + args_repset_list.append(ALL_STORAGES_REPSET) # NO_STORAGE is used to denote that an input is either a non tensor arg or # a weight tensor that is not prepacked. Similar to the above, use - # ANY_STORAGE in this case. + # ALL_STORAGES_REPSET in this case. elif arg_repset.is_empty(): - args_repset_list.append(ANY_STORAGE) + args_repset_list.append(ALL_STORAGES_REPSET) else: assert not arg_repset.is_empty() @@ -913,7 +947,7 @@ def __init__( # noqa: C901 # Repeat for output tensors. outs_repset_list = TensorRepSetList([]) - common_out_repset = ANY_STORAGE + common_out_repset = ALL_STORAGES_REPSET if num_tensors_in_node(op_node) == 1: common_out_repset = make_filtered_tensor_repset( op_node.meta["val"], outputs_repsets[0], texture_limits @@ -1086,6 +1120,25 @@ def try_constrain_with_arg_repset( self.assert_sync_contraints() return True + def try_constrain_with_out_repset(self, repset: TensorRepSet): + # Skip for operators that must synchronize the input and output representations + # or operators that have more than one output repset + if self.sync_primary_io_repr or len(self.outs_repset_list) > 1: + return False + + out_current_repset = self.outs_repset_list[0] + + if out_current_repset == repset: + return False + + if not out_current_repset.any_in_common(repset): + return False + + self.outs_repset_list[0] = out_current_repset.make_intersect(repset) + + self.assert_sync_contraints() + return True + def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]: """ For each tensor participating in the op, pick a representation for it among the diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 876f7fa8900..3ccbdc8ab85 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -6,12 +6,11 @@ # pyre-strict +import copy from functools import partial - from typing import Any, Callable, Dict, final, List import executorch.backends.vulkan.utils as utils - from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform @@ -22,15 +21,12 @@ FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, - RemoveLocalScalarDenseOpsTransform, RemoveRedundantOpsTransform, - ReplaceQDQPass, SqueezeUnsqueezeInputs, TagMemoryMetaPass, ) from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform - from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, @@ -40,7 +36,6 @@ serialize_vulkan_graph, ) from executorch.backends.xnnpack._passes import FuseBatchNormPass - from executorch.exir.backend.backend_details import ( BackendDetails, CompileSpec, @@ -48,18 +43,12 @@ PreprocessResult, ) from executorch.exir.backend.utils import DelegateMappingBuilder - from executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite from executorch.exir.pass_base import ExportPass, PassBase - from executorch.exir.passes import MemoryPlanningPass, SpecPropPass - from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass - from executorch.exir.program._program import _transform - from torch._export.verifier import Verifier - from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, ) @@ -139,15 +128,21 @@ def preprocess( # noqa: C901 module_compile_spec: List[CompileSpec], ) -> PreprocessResult: compile_options = parse_compile_spec(module_compile_spec) - limits_x = compile_options.get( - "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0] - ) - limits_y = compile_options.get( - "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1] - ) - limits_z = compile_options.get( - "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2] - ) + + default_texture_limits = copy.deepcopy(utils.DEFAULT_TEXTURE_LIMITS) + # 2048 is the typical limit value for 3D textures, but mobile GPUs often support + # 16384. Since the Vulkan delegate primarily targets mobile GPUs at the moment, + # 16394 is the default texture limit used. This option is provided as a + # convenient way to switch to using a limit of 2048 for image textures which + # will be compatible with most GPUs. + if compile_options.get("small_texture_limits", False): + default_texture_limits[0] = 2048 + default_texture_limits[1] = 2048 + default_texture_limits[2] = 2048 + + limits_x = compile_options.get("texture_limits_x", default_texture_limits[0]) + limits_y = compile_options.get("texture_limits_y", default_texture_limits[1]) + limits_z = compile_options.get("texture_limits_z", default_texture_limits[2]) texture_limits = (limits_x, limits_y, limits_z) default_storage_type = compile_options.get( @@ -173,7 +168,6 @@ def preprocess( # noqa: C901 AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(), - ReplaceQDQPass(), FoldQDQPass(), SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), @@ -193,9 +187,6 @@ def preprocess( # noqa: C901 program, [ RemoveAssertsTransform(), - # Since this pass may replace a scalar argument with a tensor argument, - # this pass may result in a non ATen compliant graph structure. - RemoveLocalScalarDenseOpsTransform(), insert_prepack_nodes, ], ) @@ -213,28 +204,33 @@ def preprocess( # noqa: C901 texture_limits, default_storage_type=default_storage_type, default_memory_layout=default_memory_layout, + force_fp16=force_fp16, ), ], ) # Finally, apply dynamic shape passes and memory planning pass. These passes # must be applied only when the graph structure is finalized. - greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False) - mem_planning_suite = MemoryPlanningAlgorithmSuite( - algo_list=[greedy_memory_planning] - ) - # This is a workaround to allow the memory planning pass to work without having - # to first apply ToOutVarPass(). See the `greedy()` function in - # `exir.memory_planning`; if this attribute isn't set, assertions in - # `collect_spec_from_nodes()` will fail. - program.graph_module.encounter_to_out_var_failure = True - program = apply_passes( - program, - [ - ConstraintBasedSymShapeEvalPass(), - MemoryPlanningPass(memory_planning_algo=mem_planning_suite), - ], - ) + final_passes = [ + ConstraintBasedSymShapeEvalPass(), + ] + if not compile_options.get("skip_memory_planning", False): + greedy_memory_planning = partial( + greedy, allow_overlapping_allocations=False + ) + mem_planning_suite = MemoryPlanningAlgorithmSuite( + algo_list=[greedy_memory_planning] + ) + # This is a workaround to allow the memory planning pass to work without having + # to first apply ToOutVarPass(). See the `greedy()` function in + # `exir.memory_planning`; if this attribute isn't set, assertions in + # `collect_spec_from_nodes()` will fail. + program.graph_module.encounter_to_out_var_failure = True + final_passes.append( + MemoryPlanningPass(memory_planning_algo=mem_planning_suite) + ) + + program = apply_passes(program, final_passes) graph_builder = VkGraphBuilder( program, diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS index 6f7b13d8026..4977ad08936 100644 --- a/backends/xnnpack/_passes/TARGETS +++ b/backends/xnnpack/_passes/TARGETS @@ -8,6 +8,7 @@ runtime.python_library( deps = [ "//caffe2:torch", "//executorch/backends/transforms:addmm_mm_to_linear", + "//executorch/backends/transforms:remove_clone_ops", "//executorch/backends/transforms:lib", "//executorch/backends/xnnpack/partition:partitioner_graphs", "//executorch/backends/xnnpack/serialization:xnnpack_schema", diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index c48896b3d81..4992d7a4abd 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -4,8 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import List, Optional, Type +from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform + from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import ( @@ -42,6 +46,11 @@ from torch.export import ExportedProgram +class XNNPACKRemoveCloneOpsTransform(RemoveCloneOpsTransform): + def __init__(self): + super().__init__(preserve_input_output_copies=True) + + class XNNPACKPassManager: def __init__( self, @@ -58,6 +67,7 @@ def __init__( if not passes: # All the XNNPACK passes self.passes = [ + XNNPACKRemoveCloneOpsTransform, # TODO - remove this pass once we have a better support for dim_order ops lowering DimOrderOpsRevertPass, ConvertToUpsampleBilinear2d, diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index 93424b1c84d..02a46a6fc47 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from . import ( # noqa node_visitor, op_abs, @@ -14,6 +16,7 @@ op_cat, op_ceiling, op_clamp, + op_clone, op_conv2d, op_div, op_dynamic_dequantize_ops, diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 68226644859..4643ada9336 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -275,7 +275,7 @@ def get_per_channel_dtype( return dtype def get_quant_params( - self, quant_params: QuantParams, xnn_graph: XNNGraph + self, quant_params: QuantParams, xnn_graph: XNNGraph, external_tag: str = None ) -> XNNQuantParams: if quant_params.per_channel: scale = cast(torch.Tensor, quant_params.scale) @@ -291,13 +291,18 @@ def get_quant_params( ctypes.POINTER(ctypes.c_char * num_bytes), ).contents scale_name = hashlib.sha256(bytes(scale_array)).hexdigest() + scale_name = "scale_" + scale_name xnn_graph.constant_data.append( ConstantDataOffset( offset=UINT64_MAX, size=num_bytes, named_key=scale_name ) ) + if external_tag is not None: + logging.info( + f"Adding constant data with name, key {scale_name} and external_tag {external_tag} to named_data_store" + ) self._named_data_store.add_named_data( - scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT + scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT, external_tag ) if quant_params.per_channel_group: @@ -470,13 +475,19 @@ def define_tensor( # noqa: C901 assert f"Unsupported weight per channel quantization axis for depthwise conv2d / conv_transpose2d : {quant_params.axis}, expecting 0 / 1." # Serialize tensor value + custom_meta = tensor.meta.get("custom", None) + external_tag = ( + custom_meta.get("delegate_constant_tag", None) if custom_meta else None + ) ser_val = ( XValue(xvalue_union=tvalue) if quant_params is None else XValue( xvalue_union=XNNQuantizedTensorValue( tensor_value=tvalue, - quant_params=self.get_quant_params(quant_params, xnn_graph), + quant_params=self.get_quant_params( + quant_params, xnn_graph, external_tag + ), ) ) ) @@ -614,7 +625,7 @@ def get_serialized_buffer_index( f"Serializing constant data node {tensor} but tensor value has no bytes", ) sha256_hash = hashlib.sha256(bytes(array)) - named_key = sha256_hash.hexdigest() + named_key = tensor.name + "_" + sha256_hash.hexdigest() size = const_val.untyped_storage().nbytes() xnn_graph.constant_data.append( @@ -626,7 +637,6 @@ def get_serialized_buffer_index( custom_meta.get("delegate_constant_tag", None) if custom_meta else None ) if external_tag is not None: - external_tag = custom_meta.get("delegate_constant_tag", None) logging.info( f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" ) diff --git a/backends/xnnpack/operators/op_clone.py b/backends/xnnpack/operators/op_clone.py new file mode 100644 index 00000000000..e4ddf187ecc --- /dev/null +++ b/backends/xnnpack/operators/op_clone.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNCopy, + XNNGraph, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class CloneVisitor(NodeVisitor): + target = "aten.clone.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # Sanity check that the input and output dim order are the same. We don't + # handle dim order conversions yet. + dim_order = node.kwargs.get("dim_order", None) + input_meta = node.args[0].meta["val"] + assert dim_order is None or list(input_meta.dim_order()) == dim_order + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNCopy( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 86baba3e3f7..5427b3a7838 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe from typing import List, Type @@ -22,6 +23,7 @@ CatConfig, CeilConfig, ClampConfig, + CloneDimOrderConfig, ConstantPadConfig, DeQuantizedPerTensorConfig, DivConfig, @@ -77,6 +79,7 @@ BMMConfig, CatConfig, CeilConfig, + CloneDimOrderConfig, ConstantPadConfig, ConvolutionConfig, ClampConfig, diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py index f65f9cb3398..d025c8e6029 100644 --- a/backends/xnnpack/partition/config/gemm_configs.py +++ b/backends/xnnpack/partition/config/gemm_configs.py @@ -458,9 +458,7 @@ def get_deps( a bool indicating if the deps are valid and a list of all the dep nodes. This handles the src partition for """ - if self.src_partitions is None: - # Cache src partitions so we don't have to recompute them every time - self.src_partitions = get_source_partitions(ep.graph, self.linear_modules) + self.src_partitions = get_source_partitions(ep.graph, self.linear_modules) # src_partition is None if node is not in source partition, # otherwise gives us the linear source partition it belongs to diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 06024c632c9..434fce1d73a 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -643,3 +643,25 @@ class SinConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] + + +class CloneDimOrderConfig(GenericNodePartitionerConfig): + target_name = "_clone_dim_order.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + if not self.check_common_constraints(node, ep): + return False + + # Only partition no-op _clone_dim_order nodes (output dim order = input). + # We can relax this in the future. + # This is also a conservative check and doesn't consider ambiguity. + dim_order = node.kwargs.get("dim_order", None) + input_meta = node.args[0].meta["val"] + if dim_order is not None and list(input_meta.dim_order()) != dim_order: + why(node, reason="Only dim-order preserving clones are supported.") + return False + + return True diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 3e697566ce5..ec937a64744 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -1459,6 +1459,34 @@ Error defineBatchMatrixMultiplyNode( return Error::Ok; } +/* + * Defines a copy node in the XNN subgraph. + */ +Error defineCopyNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const NodePtr node, + const fb_xnnpack::XNNGraph* graph) noexcept { + MAYBE_UNUSED(graph); + + auto graph_node = node->xnode_union_as_XNNCopy(); + + xnn_status status = xnn_define_copy( + subgraph_ptr, + remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create copy node %i with code: %s", + node->debug_handle(), + xnn_status_to_string(status)); + + return Error::Ok; +} + /* Returns not Implemented Error code. This function is meant to be called when the compiler encountes a XNodeType from the flatbuffer @@ -1763,6 +1791,7 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { _DEFINE(Concatenate5) _DEFINE(StaticSlice) _DEFINE(BatchMatrixMultiply) + _DEFINE(Copy) case fb_xnnpack::XNodeUnion::NONE: default: // Adding here as a catch all, just in case return &defineNotImplementedNode; diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index 239f92d899e..939bbd7a82f 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -157,6 +157,7 @@ union XNodeUnion { XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, XNNSin: _XNNNode1x1, + XNNCopy: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 92a61c5537b..08d9184b9f5 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -153,6 +153,7 @@ union XNodeUnion { XNNTanh: _XNNNode1x1, XNNExp: _XNNNode1x1, XNNSin: _XNNNode1x1, + XNNCopy: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index 2b3f8e74202..872056fa82e 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -352,6 +352,11 @@ class XNNSin(XNNNode1x1): pass +@dataclass +class XNNCopy(XNNNode1x1): + pass + + @dataclass class XNNScaledDotProductAttention: query_id: int @@ -409,6 +414,7 @@ class XNNScaledDotProductAttention: XNNTanh, XNNExp, XNNSin, + XNNCopy, ] diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS index 5f3581b6aeb..d20a6003f3f 100644 --- a/backends/xnnpack/test/TARGETS +++ b/backends/xnnpack/test/TARGETS @@ -113,3 +113,14 @@ runtime.python_test( "//executorch/examples/xnnpack:models", # @manual ], ) + +runtime.python_test( + name = "test_xnnpack_partitioner", + srcs = ["test_xnnpack_partitioner.py"], + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/exir:lib", + "//executorch/extension/pybindings:portable_lib", + ], +) diff --git a/backends/xnnpack/test/ops/test_clone.py b/backends/xnnpack/test/ops/test_clone.py new file mode 100644 index 00000000000..0396b9b2bea --- /dev/null +++ b/backends/xnnpack/test/ops/test_clone.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestClone(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Clone(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.clone(x) + return z + + class CloneWithMemoryFormat(torch.nn.Module): + def __init__(self, memory_format): + super().__init__() + self.memory_format = memory_format + + def forward(self, x): + z = torch.clone(x, memory_format=self.memory_format) + return z + + def _test_clone_partitioned(self, inputs): + """Test that dim-order preserving clones are partitioned (removed)""" + ( + Tester(self.Clone(), inputs) + .export() + .check_count({"torch.ops.aten.clone.default": 1}) + .dump_artifact() + .to_edge_transform_and_lower() + .dump_artifact() + .check_not( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_clone(self): + """Test FP16 clone - should be partitioned""" + inputs = (torch.randn(2, 3, 4, 5).to(torch.float16),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone(self): + """Test FP32 clone - should be partitioned""" + inputs = (torch.randn(2, 3, 4, 5),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone_2d(self): + """Test FP32 clone with 2D tensor - should be partitioned""" + inputs = (torch.randn(10, 20),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone_3d(self): + """Test FP32 clone with 3D tensor - should be partitioned""" + inputs = (torch.randn(2, 3, 4),) + self._test_clone_partitioned(inputs) + + def test_fp32_clone_with_contiguous_format(self): + """Test FP32 clone with contiguous memory format - should be partitioned""" + inputs = (torch.randn(1, 3, 4, 4),) + ( + Tester(self.CloneWithMemoryFormat(torch.contiguous_format), inputs) + .export() + .to_edge_transform_and_lower() + .dump_artifact() + .check_not( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_clone_with_channels_last_not_partitioned(self): + """Test FP32 clone with channels_last memory format - should NOT be partitioned""" + inputs = (torch.randn(1, 3, 4, 4),) + ( + Tester(self.CloneWithMemoryFormat(torch.channels_last), inputs) + .export() + .to_edge_transform_and_lower() + # Clone with channels_last changes dim order, so should NOT be delegated + .check( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_clone_channels_last_to_contiguous_not_partitioned(self): + """Test clone from channels_last to contiguous - should NOT be partitioned""" + + class CloneChannelsLastToContiguous(torch.nn.Module): + def forward(self, x): + # Start with channels_last input + y = x.to(memory_format=torch.channels_last) + # Clone back to contiguous (changes dim order) + z = torch.clone(y, memory_format=torch.contiguous_format) + return z + + inputs = (torch.randn(1, 3, 4, 4),) + ( + Tester(CloneChannelsLastToContiguous(), inputs) + .export() + .to_edge_transform_and_lower() + .dump_artifact() + # Clone that changes dim order should NOT be delegated + .check( + [ + "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default" + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) diff --git a/backends/xnnpack/test/ops/test_conv1d.py b/backends/xnnpack/test/ops/test_conv1d.py index 036500b29d5..35d9bced512 100644 --- a/backends/xnnpack/test/ops/test_conv1d.py +++ b/backends/xnnpack/test/ops/test_conv1d.py @@ -126,7 +126,9 @@ def _test_conv1d( # quantized operators to be loaded and we don't want to do that in the test. if not skip_to_executorch: tester.to_executorch().serialize().run_method_and_compare_outputs( - num_runs=10, atol=0.02, rtol=0.02 + num_runs=10, + atol=0.04 if quantized else 1e-03, + rtol=0.02 if quantized else 1e-03, ) def test_fp16_conv1d(self): diff --git a/backends/xnnpack/test/test_xnnpack_partitioner.py b/backends/xnnpack/test/test_xnnpack_partitioner.py index 8cd9eb92d56..894fab4098f 100644 --- a/backends/xnnpack/test/test_xnnpack_partitioner.py +++ b/backends/xnnpack/test/test_xnnpack_partitioner.py @@ -9,8 +9,13 @@ import unittest import torch +import torch.nn.functional as F + from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import to_edge, to_edge_transform_and_lower +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) from torch.export import export @@ -82,3 +87,77 @@ def test_no_warning_for_to_edge_transform_and_lower_workflow(self): log_contents = log_capture_string.getvalue() self.assertNotIn("DEPRECATION WARNING", log_contents) + + def test_multi_method_partitioning_with_shared_weights(self): + """ + Test that multi-method models with shared weights are correctly partitioned. + Verify that: + 1. Both methods are fully lowered to XNNPACK. + 2. Constants are not duplicated between named data and constant buffers. + 3. Program executes correctly. + """ + + class MultiMethodModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(8, 16) + self.linear2 = torch.nn.Linear(16, 8) + + def forward(self, x): + return self.linear2(F.sigmoid(self.linear(x))) + + def forward_2(self, x): + return self.linear2(F.relu(self.linear(x))) + + def example_inputs(self): + return (torch.randn(1, 8),) + + model = MultiMethodModel() + + # Get eager reference output. + example_inputs = model.example_inputs() + with torch.no_grad(): + fwd1_eager = model.forward(*example_inputs) + fwd2_eager = model.forward_2(*example_inputs) + + # Export both methods + ep_fwd = export(model, model.example_inputs(), strict=True) + # Patch the forward, as export only traces the 'forward' method. + model.forward = model.forward_2 + ep_fwd_2 = export(model, model.example_inputs(), strict=True) + + # Convert to edge and lower to executorch + edge = to_edge({"forward": ep_fwd, "forward_2": ep_fwd_2}) + lowered = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True)) + executorch = lowered.to_executorch() + + # Check that graph is fully delegated. + nodes_1 = list(lowered._edge_programs["forward"].graph.nodes) + nodes_2 = list(lowered._edge_programs["forward_2"].graph.nodes) + self.assertEqual(len(nodes_1), 5) + self.assertEqual(len(nodes_2), 5) + expected_node_names = [ + "x", + "lowered_module_0", + "executorch_call_delegate", + "getitem", + "output_1", + ] + for n in expected_node_names: + self.assertTrue(any(node.name == n for node in nodes_1)) + self.assertTrue(any(node.name == n for node in nodes_2)) + + # Check that weights are not duplicated. + self.assertEqual(len(executorch._named_data.pte_data), 4) + self.assertEqual(len(executorch._named_data.buffers), 4) + self.assertEqual(len(executorch._named_data.external_data), 0) + + # Check that there are no constant buffers (besides the placeholder). + self.assertEqual(len(executorch._emitter_output.program.constant_buffer), 1) + + # Check for model correctness. + executorch_module = _load_for_executorch_from_buffer(executorch.buffer) + fwd1_et = executorch_module.run_method("forward", example_inputs) + fwd2_et = executorch_module.run_method("forward_2", example_inputs) + self.assertTrue(torch.allclose(fwd1_eager, fwd1_et[0], 1e-3)) + self.assertTrue(torch.allclose(fwd2_eager, fwd2_et[0], 1e-3)) diff --git a/codegen/tools/gen_ops_def.py b/codegen/tools/gen_ops_def.py index aba3f9242ac..98fdab73fd1 100644 --- a/codegen/tools/gen_ops_def.py +++ b/codegen/tools/gen_ops_def.py @@ -23,7 +23,7 @@ def get_operators(model_file: str) -> List[Operator]: print("Processing model file: ", model_file) with open(model_file, "rb") as f: flatbuffer = f.read() - program = _deserialize_pte_binary(flatbuffer) + program = _deserialize_pte_binary(flatbuffer).program print(f"Program loaded from model file: {model_file}") operators = program.execution_plan[0].operators return operators diff --git a/desktop/README.md b/desktop/README.md index 5a76aeb5fcb..2c00be632e7 100644 --- a/desktop/README.md +++ b/desktop/README.md @@ -13,7 +13,7 @@ ExecuTorch is a lightweight, flexible runtime designed for efficient AI inferenc With increased demand for local inference on consumer desktops and laptops, exemplified by popular runtimes like llama.cpp and MLX, ExecuTorch is now experimenting with CUDA and Metal support. This is achieved by leveraging Inductor compiler technology from PyTorch, specifically using Ahead-of-Time Inductor [AOTI](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) to avoid reinventing the wheel. ## Key Benefits -- **Model Agnostic**: Validated on models such as [Voxtral](../examples/models/voxtral), [Gemma3-4b](../examples/models/gemma3), ResNet, and Whisper (WIP). Theoretically, any model exportable via torch.export is supported. +- **Model Agnostic**: Validated on models such as [Voxtral](../examples/models/voxtral), [Gemma3-4b](../examples/models/gemma3), ResNet, and [Whisper](../examples/models/whisper/README.md). Theoretically, any model exportable via torch.export is supported. - **PyTorch Ecosystem Integration**: Enables workflows for fine-tuning, quantization, and compilation within the PyTorch ecosystem. - **No Python Runtime During Inference**: Ideal for native applications (e.g., written in C++) embedding AI capabilities. - **No libtorch Dependency**: Reduces binary size, making deployment easier for resource-constrained applications. diff --git a/devtools/bundled_program/test/test_bundle_data.py b/devtools/bundled_program/test/test_bundle_data.py index a587a8672e9..9fdeb4a776d 100644 --- a/devtools/bundled_program/test/test_bundle_data.py +++ b/devtools/bundled_program/test/test_bundle_data.py @@ -18,7 +18,7 @@ from executorch.devtools.bundled_program.util.test_util import ( get_common_executorch_program, ) -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _PTEFile, _serialize_pte_binary class TestBundle(unittest.TestCase): @@ -72,7 +72,11 @@ def test_bundled_program(self) -> None: self.assertEqual( bundled_program.serialize_to_schema().program, - bytes(_serialize_pte_binary(executorch_program.executorch_program)), + bytes( + _serialize_pte_binary( + pte_file=_PTEFile(program=executorch_program.executorch_program) + ) + ), ) def test_bundled_program_from_pte(self) -> None: diff --git a/devtools/visualization/visualization_utils.py b/devtools/visualization/visualization_utils.py index 6dd0c327048..b76d164b61b 100644 --- a/devtools/visualization/visualization_utils.py +++ b/devtools/visualization/visualization_utils.py @@ -26,7 +26,7 @@ ) except ImportError: print( - "Error: 'model_explorer' is not installed. Install using devtools/install_requirement.sh" + "Error: 'model_explorer' is not installed. Install using devtools/install_requirements.sh" ) raise diff --git a/devtools/visualization/visualization_utils_test.py b/devtools/visualization/visualization_utils_test.py index 4f44241518f..0d470a7f359 100644 --- a/devtools/visualization/visualization_utils_test.py +++ b/devtools/visualization/visualization_utils_test.py @@ -24,7 +24,7 @@ from model_explorer.config import ModelExplorerConfig # type: ignore except ImportError: print( - "Error: 'model_explorer' is not installed. Install using devtools/install_requirement.sh" + "Error: 'model_explorer' is not installed. Install using devtools/install_requirements.sh" ) raise diff --git a/docs/Makefile b/docs/Makefile index 627358d0387..c4f5e571ff8 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -13,6 +13,12 @@ BUILDDIR = _build html-noplot: $(SPHINXBUILD) -D plot_gallery=0 -b html $(SPHINXOPTS) "$(SOURCEDIR)" "$(BUILDDIR)/html" +html-stable: + # Stable differs from 'make html' in that it shows the release version + # instead of "main (version)" in the docs and version switcher. + # See conf.py for more details. + RELEASE=true $(MAKE) html + help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/android-arm-vgf.md b/docs/source/android-arm-vgf.md index cc39b53e176..51111900add 100644 --- a/docs/source/android-arm-vgf.md +++ b/docs/source/android-arm-vgf.md @@ -1 +1 @@ -```{include} backends-arm-vgf.md +```{include} backends/arm-vgf/arm-vgf-overview.md diff --git a/docs/source/backends-arm-vgf.md b/docs/source/backends-arm-vgf.md deleted file mode 100644 index 97d7bf193e3..00000000000 --- a/docs/source/backends-arm-vgf.md +++ /dev/null @@ -1,204 +0,0 @@ -# Arm® VGF Backend - -The Arm VGF backend is the ExecuTorch solution for lowering PyTorch models to VGF compatible hardware. -It leverages the TOSA operator set and the [ML SDK for Vulkan®](https://github.com/arm/ai-ml-sdk-for-vulkan?tab=readme-ov-file) to produce a .PTE file. -The VGF backend also supports execution from a .PTE file and provides functionality to extract the corresponding VGF file for integration into various applications. - -## Features - -- Wide operator support for delegating large parts of models to the VGF target. -- A quantizer that optimizes quantization for the VGF target. - -## Target Requirements -The target system must include ML SDK for Vulkan and a Vulkan driver with Vulkan API >= 1.3. - -## Development Requirements - -```{tip} -All requirements can be downloaded using `examples/arm/setup.sh --enable-mlsdk-deps --disable-ethos-u-deps` and added to the path using -`source examples/arm/ethos-u-scratch/setup_path.sh` -``` - -For the AOT flow, compilation of a model to `.pte` format using the VGF backend, the requirements are: -- [TOSA Serialization Library](https://www.mlplatform.org/tosa/software.html) for serializing the Exir IR graph into TOSA IR. -- [ML SDK Model Converter](https://github.com/arm/ai-ml-sdk-model-converter) for converting TOSA flatbuffers to VGF files. - -And for building and running your application using the generic executor_runner: -- [Vulkan API](https://www.vulkan.org) should be set up locally for GPU execution support. -- [ML Emulation Layer for Vulkan](https://github.com/arm/ai-ml-emulation-layer-for-vulkan) for testing on Vulkan API. - -## Using the Arm VGF Backend -The [VGF Minimal Example](https://github.com/pytorch/executorch/blob/main/examples/arm/vgf_minimal_example.ipynb) demonstrates how to lower a module using the VGF backend. - -The main configuration point for the lowering is the `VgfCompileSpec` consumed by the partitioner and quantizer. -The full user-facing API is documented below. - -```python -class VgfCompileSpec(tosa_spec: executorch.backends.arm.tosa.specification.TosaSpecification | str | None = None, compiler_flags: list[str] | None = None) -``` -Compile spec for VGF compatible targets. - -Attributes: -- **tosa_spec**: A TosaSpecification, or a string specifying a TosaSpecification. -- **compiler_flags**: Extra compiler flags for converter_backend. - -```python -def VgfCompileSpec.dump_debug_info(self, debug_mode: executorch.backends.arm.common.arm_compile_spec.ArmCompileSpec.DebugMode | None): -``` -Dump debugging information into the intermediates path. - -```python -def VgfCompileSpec.dump_intermediate_artifacts_to(self, output_path: str | None): -``` -Sets a path for dumping intermediate results during lowering such as tosa and pte. - -```python -def VgfCompileSpec.get_intermediate_path(self) -> str | None: -``` -Returns the path for dumping intermediate results during lowering such as tosa and pte. - -```python -def VgfCompileSpec.get_output_format() -> str: -``` -Returns a constant string that is the output format of the class. - - - -### Partitioner API -```python -class VgfPartitioner(compile_spec: executorch.backends.arm.vgf.compile_spec.VgfCompileSpec, additional_checks: Optional[Sequence[torch.fx.passes.operator_support.OperatorSupportBase]] = None) -> None -``` -Partitions subgraphs supported by the Arm Vgf backend. - -Attributes: -- **compile_spec**:List of CompileSpec objects for Vgf backend. -- **additional_checks**: Optional sequence of additional operator support checks. - -```python -def VgfPartitioner.ops_to_not_decompose(self, ep: torch.export.exported_program.ExportedProgram) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.node.Node], bool]]]: -``` -Returns a list of operator names that should not be decomposed. When these ops are -registered and the `to_backend` is invoked through to_edge_transform_and_lower it will be -guaranteed that the program that the backend receives will not have any of these ops -decomposed. - -Returns: -- **List[torch._ops.OpOverload]**: a list of operator names that should not be decomposed. -- **Optional[Callable[[torch.fx.Node], bool]]]**: an optional callable, acting as a filter, that users can provide - which will be called for each node in the graph that users can use as a filter for certain - nodes that should be continued to be decomposed even though the op they correspond to is - in the list returned by ops_to_not_decompose. - -```python -def VgfPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult: -``` -Returns the input exported program with newly created sub-Modules encapsulating -specific portions of the input "tagged" for delegation. - -The specific implementation is free to decide how existing computation in the -input exported program should be delegated to one or even more than one specific -backends. - -The contract is stringent in that: -* Each node that is intended to be delegated must be tagged -* No change in the original input exported program (ExportedProgram) representation can take -place other than adding sub-Modules for encapsulating existing portions of the -input exported program and the associated metadata for tagging. - -Args: -- **exported_program**: An ExportedProgram in Edge dialect to be partitioned for backend delegation. - -Returns: -- **PartitionResult**: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers. - - - -### Quantizer -The VGF quantizer supports [Post Training Quantization (PT2E)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) -and [Quantization-Aware Training (QAT)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_qat.html) quantization. - -Currently the symmetric `int8` config defined by `executorch.backends.arm.quantizer.arm_quantizer.get_symmetric_quantization_config` is -the main config available to use with the VGF quantizer. - -```python -class VgfQuantizer(compile_spec: 'VgfCompileSpec') -> 'None' -``` -Quantizer supported by the Arm Vgf backend. - -Attributes: -- **compile_spec**: VgfCompileSpec, specifies the compilation configuration. - -```python -def VgfQuantizer.set_global(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': -``` -Set quantization_config for submodules that are not already annotated by name or type filters. - -Args: -- **quantization_config**: Specifies the quantization scheme for the weights and activations - -```python -def VgfQuantizer.set_io(self, quantization_config): -``` -Set quantization_config for input and output nodes. - -Args: -- **quantization_config**: Specifies the quantization scheme for the weights and activations - -```python -def VgfQuantizer.set_module_name(self, module_name: 'str', quantization_config: 'Optional[QuantizationConfig]') -> 'TOSAQuantizer': -``` -Set quantization_config for a submodule with name: `module_name`, for example: -quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator -patterns in the submodule with this module name with the given `quantization_config` - -Args: -- **module_name**: Name of the module to which the quantization_config is set. -- **quantization_config**: Specifies the quantization scheme for the weights and activations. - -Returns: -- **TOSAQuantizer**: The quantizer instance with the updated module name configuration - -```python -def VgfQuantizer.set_module_type(self, module_type: 'Callable', quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': -``` -Set quantization_config for a submodule with type: `module_type`, for example: -quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator -patterns in the submodule with this module type with the given `quantization_config` - -Args: -- **module_type**: Type of module to which the quantization_config is set. -- **quantization_config**: Specifies the quantization scheme for the weights and activations. - -Returns: -- **TOSAQuantizer**: The quantizer instance with the updated module type configuration - -```python -def VgfQuantizer.transform_for_annotation(self, model: 'GraphModule') -> 'GraphModule': -``` -An initial pass for transforming the graph to prepare it for annotation. -Currently transforms scalar values to tensor attributes. - -Args: -- **model**: Module that is transformed. - -Returns: - The transformed model. - - -### Supported Quantization Schemes -The quantization schemes supported by the VGF Backend are: -- 8-bit symmetric weights with 8-bit asymmetric activations (via the PT2E quantization flow). - - Supports both static and dynamic activations - - Supports per-channel and per-tensor schemes - -Weight-only quantization is not currently supported on VGF - -## Runtime Integration - -The VGF backend can use the default ExecuTorch runner. The steps required for building and running it are explained in the previously mentioned [VGF Backend Tutorial](https://docs.pytorch.org/executorch/stable/tutorial-arm-ethos-u.html). -The example application is recommended to use for testing basic functionality of your lowered models, as well as a starting point for developing runtime integrations for your own targets. - -### VGF Adapter for Model Explorer - -The [VGF Adapter for Model Explorer](https://github.com/arm/vgf-adapter-model-explorer) enables visualization of -VGF files and can be useful for debugging. diff --git a/docs/source/backends-nxp.md b/docs/source/backends-nxp.md index f4f7762c769..20dd180fb31 100644 --- a/docs/source/backends-nxp.md +++ b/docs/source/backends-nxp.md @@ -56,16 +56,20 @@ List of Aten operators supported by Neutron quantizer: `reshape`, `view`, `softmax.int`, `sigmoid`, `tanh`, `tanh_` #### Example + +To quantize your model, you can either use the PT2E workflow: ```python import torch from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e # Prepare your model in Aten dialect aten_model = get_model_in_aten_dialect() # Prepare calibration inputs, each tuple is one example, example tuple has items for each model input calibration_inputs: list[tuple[torch.Tensor, ...]] = get_calibration_inputs() -quantizer = NeutronQuantizer() +target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09") +quantizer = NeutronQuantizer(neutron_target_spec) m = prepare_pt2e(aten_model, quantizer) for data in calibration_inputs: @@ -73,6 +77,22 @@ for data in calibration_inputs: m = convert_pt2e(m) ``` +Or you can use the predefined function for post training quantization from NXP backend implementation: +```python +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.quantizer.utils import post_training_quantize + +... + +target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09") +quantized_graph_module = post_training_quantize( + aten_model, + calibration_inputs, + NeutronQuantizer(neutron_target_spec=target_spec), +) +``` + ## Runtime Integration To learn how to run the converted model on the NXP hardware, use one of our example projects on using ExecuTorch runtime from MCUXpresso IDE example projects list. diff --git a/docs/source/backends-overview.md b/docs/source/backends-overview.md index ddb55f2afec..7b50c60f521 100644 --- a/docs/source/backends-overview.md +++ b/docs/source/backends-overview.md @@ -26,8 +26,8 @@ Backends are the bridge between your exported model and the hardware it runs on. | [Vulkan ](/backends/vulkan/vulkan-overview.md) | Android | GPU | Android GPU acceleration | | [Qualcomm](backends-qualcomm) | Android | NPU | Qualcomm SoCs | | [MediaTek](backends-mediatek) | Android | NPU | MediaTek SoCs | -| [ARM EthosU](backends-arm-ethos-u) | Embedded | NPU | ARM MCUs | -| [ARM VGF](backends-arm-vgf) | Android | NPU | ARM platforms | +| [Arm Ethos-U](/backends/arm-ethos-u/arm-ethos-u-overview.md) | Embedded | NPU | Arm MCUs | +| [Arm VGF](/backends/arm-vgf/arm-vgf-overview.md) | Android | GPU | Arm platforms | | [OpenVINO](build-run-openvino) | Embedded | CPU/GPU/NPU | Intel SoCs | | [NXP](backends-nxp) | Embedded | NPU | NXP SoCs | | [Cadence](backends-cadence) | Embedded | DSP | DSP-optimized workloads | @@ -56,8 +56,8 @@ backends/mps/mps-overview backends/vulkan/vulkan-overview backends-qualcomm backends-mediatek -backends-arm-ethos-u -backends-arm-vgf +backends/arm-ethos-u/arm-ethos-u-overview +backends/arm-vgf/arm-vgf-overview build-run-openvino backends-nxp backends-cadence diff --git a/docs/source/backends-qualcomm.md b/docs/source/backends-qualcomm.md index 7eb0405b309..6c5397f02be 100644 --- a/docs/source/backends-qualcomm.md +++ b/docs/source/backends-qualcomm.md @@ -71,7 +71,7 @@ This example is verified with SM8550 and SM8450. ### Software: - Follow ExecuTorch recommended Python version. - - A compiler to compile AOT parts, e.g., the GCC compiler comes with Ubuntu LTS. + - A compiler to compile AOT parts, e.g., the GCC compiler comes with Ubuntu LTS. g++ version need to be 13 or higher. - [Android NDK](https://developer.android.com/ndk). This example is verified with NDK 26c. - (Optional) Target toolchain for linux embedded platform. - [Qualcomm AI Engine Direct SDK](https://developer.qualcomm.com/software/qualcomm-ai-engine-direct-sdk) diff --git a/docs/source/backends-arm-ethos-u.md b/docs/source/backends/arm-ethos-u/arm-ethos-u-overview.md similarity index 57% rename from docs/source/backends-arm-ethos-u.md rename to docs/source/backends/arm-ethos-u/arm-ethos-u-overview.md index 2dfddacd20f..6823473afb8 100644 --- a/docs/source/backends-arm-ethos-u.md +++ b/docs/source/backends/arm-ethos-u/arm-ethos-u-overview.md @@ -1,4 +1,4 @@ -# Arm® Ethos™-U NPU Backend +# Arm Ethos-U Backend The Arm® Ethos™-U backend targets Edge/IoT-type AI use-cases by enabling optimal execution of quantized models on [Arm® Ethos™-U55 NPU](https://www.arm.com/products/silicon-ip-cpu/ethos/ethos-u55), [Arm® Ethos™-U65 NPU](https://www.arm.com/products/silicon-ip-cpu/ethos/ethos-u65), and @@ -6,18 +6,18 @@ The Arm® Ethos™-U backend targets Edge/IoT-type AI use-cases by enabli [ethos-u-vela](https://pypi.org/project/ethos-u-vela/) graph compiler. This document is a technical reference for using the Ethos-U backend, for a top level view with code examples please refer to the [Arm Ethos-U Backend Tutorial](https://docs.pytorch.org/executorch/stable/tutorial-arm-ethos-u.html). - ## Features + - Wide operator support for delegating large parts of models to highly optimized and low power Ethos-U NPUs. - A quantizer that optimizes quantization for the NPU target. - Example runtime integration for easy hardware bringup. - ## Target Requirements -The target system must include an Ethos-U NPU. +The target system must include an Ethos-U NPU. ## Development Requirements + ```{tip} All requirements can be downloaded using `examples/arm/setup.sh --i-agree-to-the-contained-eula` and added to the path using set(CMAKE_INSTALL_PREFIX "${CMAKE_BINARY_DIR}") @@ -35,17 +35,17 @@ And for building and running the example application available in `examples/arm/ Fixed Virtual Platforms (FVPs) are freely available emulators provided by Arm for easy embedded development without the need for a physical development board. +## Using the Arm Ethos-U Backend -## Using the Arm Ethos-U backend The main configuration point for the lowering is the `EthosUCompileSpec` consumed by the partitioner and quantizer. The full user-facing API is documented below. ```python class EthosUCompileSpec(target: str, system_config: str | None = None, memory_mode: str | None = None, extra_flags: list[str] | None = None, config_ini: str | None = 'Arm/vela.ini') ``` -Compile spec for Ethos-U NPU +Compile spec for Ethos-U NPU. -Attributes: +Args: - **target**: Ethos-U accelerator configuration, e.g. ethos-u55-128. - **system_config**: System configuration to select from the Vela configuration file. - **memory_mode**: Memory mode to select from the Vela configuration file. @@ -57,111 +57,43 @@ def EthosUCompileSpec.dump_debug_info(self, debug_mode: executorch.backends.arm. ``` Dump debugging information into the intermediates path. -```python -def EthosUCompileSpec.dump_intermediate_artifacts_to(self, output_path: str | None): -``` -Sets a path for dumping intermediate results during lowering such as tosa and pte. - -```python -def EthosUCompileSpec.get_intermediate_path(self) -> str | None: -``` -Returns the path for dumping intermediate results during lowering such as tosa and pte. - -```python -def EthosUCompileSpec.get_output_format() -> str: -``` -Returns a constant string that is the output format of the class. - +Args: +- **debug_mode**: The debug mode to use for dumping debug information. -### Partitioner API ```python -class EthosUPartitioner(compile_spec: executorch.backends.arm.ethosu.compile_spec.EthosUCompileSpec, additional_checks: Optional[Sequence[torch.fx.passes.operator_support.OperatorSupportBase]] = None) -> None +def EthosUCompileSpec.dump_intermediate_artifacts_to(self, output_path: str | None): ``` -Partitions subgraphs supported by the Arm Ethos-U backend. +Sets a path for dumping intermediate results during such as tosa and pte. -Attributes: -- **compile_spec**: List of CompileSpec objects for Ethos-U backend. -- **additional_checks**: Optional sequence of additional operator support checks. +Args: +- **output_path**: Path to dump intermediate results to. ```python -def EthosUPartitioner.ops_to_not_decompose(self, ep: torch.export.exported_program.ExportedProgram) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.node.Node], bool]]]: +def EthosUCompileSpec.get_intermediate_path(self) -> str | None: ``` -Returns a list of operator names that should not be decomposed. When these ops are -registered and the `to_backend` is invoked through to_edge_transform_and_lower it will be -guaranteed that the program that the backend receives will not have any of these ops -decomposed. +Gets the path used for dumping intermediate results such as tosa and pte. Returns: -- **List[torch._ops.OpOverload]**: a list of operator names that should not be decomposed. -- **Optional[Callable[[torch.fx.Node], bool]]]**: an optional callable, acting as a filter, that users can provide - which will be called for each node in the graph that users can use as a filter for certain - nodes that should be continued to be decomposed even though the op they correspond to is - in the list returned by ops_to_not_decompose. + Path where intermediate results are saved. ```python -def EthosUPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult: +def EthosUCompileSpec.get_output_format() -> str: ``` -Returns the input exported program with newly created sub-Modules encapsulating -specific portions of the input "tagged" for delegation. +Returns a constant string that is the output format of the class. -The specific implementation is free to decide how existing computation in the -input exported program should be delegated to one or even more than one specific -backends. -The contract is stringent in that: -* Each node that is intended to be delegated must be tagged -* No change in the original input exported program (ExportedProgram) representation can take -place other than adding sub-Modules for encapsulating existing portions of the -input exported program and the associated metadata for tagging. -Args: -- **exported_program**: An ExportedProgram in Edge dialect to be partitioned for backend delegation. +### Partitioner API -Returns: -- **PartitionResult**: includes the tagged graph and the delegation spec to indicate what backend_id and compile_spec is used for each node and the tag created by the backend developers. +See [Partitioner API](arm-ethos-u-partitioner.md) for more information of the Partitioner API. +## Quantization -### Quantizer -Since the Ethos-U backend is integer-only, all ops intended to run on the NPU needs to be quantized. The Ethos-U quantizer supports +Since the Ethos-U backend is integer-only, all operators intended be executed on the NPU needs to be quantized. The Ethos-U quantizer supports [Post Training Quantization (PT2E)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) and [Quantization-Aware Training (QAT)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_qat.html) quantization. -Currently, the symmetric `int8` config defined by `executorch.backends.arm.quantizer.arm_quantizer.get_symmetric_quantization_config` is -the main config available to use with the Ethos-U quantizer. - -```python -class EthosUQuantizer(compile_spec: 'EthosUCompileSpec') -> 'None' -``` - -```python -def EthosUQuantizer.set_global(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': -``` -Set quantization_config for submodules that are not already annotated by name or type filters. - -```python -def EthosUQuantizer.set_io(self, quantization_config): -``` -Set quantization_config for input and output nodes. - -```python -def EthosUQuantizer.set_module_name(self, module_name: 'str', quantization_config: 'Optional[QuantizationConfig]') -> 'TOSAQuantizer': -``` -Set quantization_config for a submodule with name: `module_name`, for example: -quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator -patterns in the submodule with this module name with the given `quantization_config` - -```python -def EthosUQuantizer.set_module_type(self, module_type: 'Callable', quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': -``` -Set quantization_config for a submodule with type: `module_type`, for example: -quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator -patterns in the submodule with this module type with the given `quantization_config` - -```python -def EthosUQuantizer.transform_for_annotation(self, model: 'GraphModule') -> 'GraphModule': -``` -An initial pass for transforming the graph to prepare it for annotation. - +For more information on quantization, see [Quantization](arm-ethos-u-quantization.md) ## Runtime Integration @@ -169,8 +101,8 @@ An example runtime application is available in [examples/arm/executor_runner](ht The example application is recommended to use for testing basic functionality of your lowered models, as well as a starting point for developing runtime integrations for your own targets. For an in-depth explanation of the architecture of the executor_runner and the steps required for doing such an integration, please refer to [Ethos-U porting guide](https://github.com/pytorch/executorch/blob/main/examples/arm/ethos-u-porting-guide.md). - ### Ethos-U memory modes + The Ethos-U NPU provides two distinct memory interfaces: - One interface for **low-latency, high-bandwidth memory**. - On all Ethos-U NPUs(Ethos-U55, Ethos-U65, Ethos-U85), the low-latency memory is usually the SRAM of the SoC. @@ -195,6 +127,7 @@ The placement of the scratch buffer and the Neural Network determine the memory Here is an in-depth explanation of the different modes: #### 1. Sram-Only Memory Mode + - Ethos-U scratch buffer resides in the SRAM. - Neural Network resides in the SRAM. - Ethos-U fast scratch buffer is not used. @@ -209,6 +142,7 @@ Below, you can see a visual representation of the placement of the two logical m ![](backend-arm-ethos-u-sram_only.png) #### 2. Shared-Sram Memory Mode + - Ethos-U scratch buffer resides in the SRAM. - Neural Network resides in the External memory. - Ethos-U fast scratch buffer is not used. @@ -225,6 +159,7 @@ Below, you can see a visual representation of the placement of the two logical m ![](backend-arm-ethos-u-shared_sram.png) #### 3. Dedicated-Sram Memory Mode + - Ethos-U scratch buffer resides in the External memory. - Neural Network resides in the External memory. - Ethos-U fast scratch buffer resides in the on-chip memory. @@ -241,45 +176,27 @@ Below, you can see a visual representation of the placement of the two logical m ![](backend-arm-ethos-u-dedicated_sram.png) - The memory modes are defined within the [vela.ini file](https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/ethosu/config_files/Arm/vela.ini?ref_type=heads). When you install ExecuTorch for the Ethos-U backend, you automatically install the compiler containing the vela.ini file so you can directly create a compile specification with these memory modes. -#### Interpreting the output from the Ethos-U compiler regarding the memory footprint -As part of the `to_edge_transform_and_lower` step, you will see a memory footprint information presented as: -``` -Total SRAM used 2467.27 KiB -Total Off-chip Flash used 12.20 KiB -```` +## Reference -The `Total SRAM used` indicates the peak SRAM utilization needed by the NPU in order to perform an inference. In the snippet above, the Ethos-U compiler requires 2467.27 KiB of SRAM in order to schedule the inference. -Therefore, from an application standpoint, you need to ensure you have at least 2467.27 KiB of SRAM on the SoC to run this model. The Ethos-U compiler provides a scheduling algorithm allowing to -lower the peak SRAM usage within reasonable limits, you need to add the `--optimise Size` or `--arena-cache-size` CLI options for to the compile spec. You can read more about the options of the -Ethos-U compiler in the documentation [here](https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md#optimise). If the peak SRAM usage remains too high in -Shared Sram memory mode, you would need to us the Dedicated Sram mode in order to store the Neural Network and the Ethos-U scratch buffer in the external memory. -The main advantage of the Dedicated_Sram memory mode is that you can run large models and still benefit from the low-latency/high-bandwidth of the SRAM, used as a cache. -It is important to highlight that when you specify a memory mode in the compile spec, in the runtime, the user is expected to place the scratch buffer and NN in the correct memory location. -In other words, when you specify for ex. Shared Sram memory mode, the runtime application logic should place the ethos-U scratch buffer in the on-chip memory and the NN in the external memory for optimal performance. -You can see how this coupling between the memory mode and runtime application is done in the -[Ethos-U porting guide](https://github.com/pytorch/executorch/blob/main/examples/arm/ethos-u-porting-guide.md) +**→{doc}`/backends/arm-ethos-u/arm-ethos-u-partitioner` — Partitioner options.** +**→{doc}`/backends/arm-ethos-u/arm-ethos-u-quantization` — Supported quantization schemes.** -### Bundled.io and ETdump +**→{doc}`/backends/arm-ethos-u/arm-ethos-u-troubleshooting` — Troubleshooting and common issues.** -The arm_executor_runner supports [bundled-io](https://docs.pytorch.org/executorch/0.4/bundled-io.html) and [ETdump](https://docs.pytorch.org/executorch/stable/etdump.html) debugging tools. +**→{doc}`/backends/arm-ethos-u/tutorials/arm-ethos-u-tutorials` — Tutorials.** -To enable bundled-io, set `EXECUTORCH_BUILD_DEVTOOLS` when building Executorch and `DET_BUNDLE_IO` when building the executor_runner. To enable ETdump, set `EXECUTORCH_BUILD_ARM_ETDUMP` when building Executorch and `DEXECUTORCH_ENABLE_EVENT_TRACER` -when building the executor_runner. +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Arm Ethos-U Backend -## Memory formats - -Tensors of rank 4 and higher have two differing [memory format](https://pytorch.org/blog/tensor-memory-format-matters/) standards used. -Pytorch defaults to contiguous/ channels first/ NCHW memory formats, compared to TOSA which only supports channels last/NHWC memory format. -To support this, the backend inserts a transpose in the beginning if the incoming memory format is contiguous, and correspondingly a -transpose in the end if the outgoing memory format is contiguous. Note that this means that you may avoid transposing the data unneccessarily if the runtime integration and -full network is converted to use channels last. A word of caution must be given here however - changing memory format has been noted to have side effects such as -unsupported ops being inserted into the graph, and it is currently not widely tested, so the feature must so far be viewed as experimental. - -## See Also -- [Arm Ethos-U Backend Tutorial](tutorial-arm-ethos-u.md) \ No newline at end of file +arm-ethos-u-partitioner +arm-ethos-u-quantization +arm-ethos-u-troubleshooting +tutorials/arm-ethos-u-tutorials +``` diff --git a/docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md b/docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md new file mode 100644 index 00000000000..09664bd6ccc --- /dev/null +++ b/docs/source/backends/arm-ethos-u/arm-ethos-u-partitioner.md @@ -0,0 +1,47 @@ +# Partitioner API + +The `EthosUPartitioner` controls what parts of a model is delegated to the Arm Ethos-U backend. Below is a reference of the various functions the partitioner provides: + +```python +class EthosUPartitioner(compile_spec: executorch.backends.arm.ethosu.compile_spec.EthosUCompileSpec, additional_checks: Optional[Sequence[torch.fx.passes.operator_support.OperatorSupportBase]] = None) -> None +``` +Partitions subgraphs supported by the Arm Ethos-U backend. + +Args: +- **compile_spec**: List of CompileSpec objects for Ethos-U backend. +- **additional_checks**: Optional sequence of additional operator support checks. + +```python +def EthosUPartitioner.ops_to_not_decompose(self, ep: torch.export.exported_program.ExportedProgram) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.node.Node], bool]]]: +``` +Return operators and a filter that should not be decomposed. + +Provide a base set of ops to preserve as-is and a predicate that keeps +certain activations whole when surrounded by quantize/dequantize ops in +a quantized graph. This helps downstream TOSA lowering and delegation. + +Args: +- **ep (ExportedProgram)**: Program used to infer target-specific policy. + +Returns: +- **Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]**: + A list of op overloads to keep intact, and an optional filter + function that returns True when an op should not be decomposed. + +```python +def EthosUPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult: +``` +Partition the program and tag TOSA-compatible subgraphs. + +Run the FX capability-based partitioner to propose subgraphs, then +refine tags by removing boundary-only quantize/dequantize nodes and by +rejecting partitions that would lower to no-ops. Emit a detailed report +of rejected nodes and their reasons. + +Args: +- **exported_program (ExportedProgram)**: Program to analyze and + partition. + +Returns: +- **PartitionResult**: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. diff --git a/docs/source/backends/arm-ethos-u/arm-ethos-u-quantization.md b/docs/source/backends/arm-ethos-u/arm-ethos-u-quantization.md new file mode 100644 index 00000000000..80ac51ad644 --- /dev/null +++ b/docs/source/backends/arm-ethos-u/arm-ethos-u-quantization.md @@ -0,0 +1,71 @@ +# Quantization + +The Arm Ethos-U delegate only supports the execution of quantized models. To quantize a model so that is supported by this delegate, the `EthosUQuantizer` should be used. + +Currently, the symmetric `int8` config defined by `executorch.backends.arm.quantizer.arm_quantizer.get_symmetric_quantization_config` is the main config available to use with the Ethos-U quantizer. + +### Supported Quantization Schemes + +The Arm Ethos-U delegate supports the following quantization schemes: + +- 8-bit symmetric weights with 8-bit asymmetric activations (via the PT2E quantization flow). +- Limited support for 16-bit quantization with 16-bit activations and 8-bit weights (a.k.a 16x8 quantization). This is under development. + +### Quantization API + +```python +class EthosUQuantizer(compile_spec: 'EthosUCompileSpec') -> 'None' +``` +Quantizer supported by the Arm Ethos-U backend. + +Args: +- **compile_spec**: A EthosUCompileSpec instance. + +```python +def EthosUQuantizer.set_global(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for submodules that are not already annotated by name or type filters. + +Args: +- **quantization_config**: The QuantizationConfig to set as global configuration. + +```python +def EthosUQuantizer.set_io(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for input and output nodes. + +Args: +- **quantization_config**: The QuantizationConfig to set for input and output nodes. + +```python +def EthosUQuantizer.set_module_name(self, module_name: 'str', quantization_config: 'Optional[QuantizationConfig]') -> 'TOSAQuantizer': +``` +Set quantization_config for a submodule with name: `module_name`, for example: +quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator +patterns in the submodule with this module name with the given `quantization_config` + +Args: +- **module_name**: The name of the submodule to set the quantization config for. +- **quantization_config**: The QuantizationConfig to set for the submodule. + +```python +def EthosUQuantizer.set_module_type(self, module_type: 'Callable', quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for a submodule with type: `module_type`, for example: +quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator +patterns in the submodule with this module type with the given `quantization_config`. + +Args: +- **module_type**: The type of the submodule to set the quantization config for. +- **quantization_config**: The QuantizationConfig to set for the submodule. + +```python +def EthosUQuantizer.transform_for_annotation(self, model: 'GraphModule') -> 'GraphModule': +``` +An initial pass for transforming the graph to prepare it for annotation. +Currently transforms scalar values to tensor attributes. + +Args: +- **model**: The model to transform. +Returns: + The transformed model. diff --git a/docs/source/backends/arm-ethos-u/arm-ethos-u-troubleshooting.md b/docs/source/backends/arm-ethos-u/arm-ethos-u-troubleshooting.md new file mode 100644 index 00000000000..af31ed7fd0c --- /dev/null +++ b/docs/source/backends/arm-ethos-u/arm-ethos-u-troubleshooting.md @@ -0,0 +1,38 @@ +# Arm Ethos-U Troubleshooting + +This page describes common issues that you may encounter when using the Arm Ethos-U backend and how to debug and resolve them. + +## Understanding memory footprint using the Ethos-U compiler + +As part of the `to_edge_transform_and_lower` step, you will see a memory footprint information presented as: + +``` +Total SRAM used 2467.27 KiB +Total Off-chip Flash used 12.20 KiB +``` + +The `Total SRAM used` indicates the peak SRAM utilization needed by the NPU in order to perform an inference. In the snippet above, the Ethos-U compiler requires 2467.27 KiB of SRAM in order to schedule the inference. +Therefore, from an application standpoint, you need to ensure you have at least 2467.27 KiB of SRAM on the SoC to run this model. The Ethos-U compiler provides a scheduling algorithm allowing to +lower the peak SRAM usage within reasonable limits, you need to add the `--optimise Size` or `--arena-cache-size` CLI options for to the compile spec. You can read more about the options of the +Ethos-U compiler in the documentation [here](https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md#optimise). If the peak SRAM usage remains too high in +Shared Sram memory mode, you would need to us the Dedicated Sram mode in order to store the Neural Network and the Ethos-U scratch buffer in the external memory. +The main advantage of the Dedicated_Sram memory mode is that you can run large models and still benefit from the low-latency/high-bandwidth of the SRAM, used as a cache. +It is important to highlight that when you specify a memory mode in the compile spec, in the runtime, the user is expected to place the scratch buffer and NN in the correct memory location. +In other words, when you specify for ex. Shared Sram memory mode, the runtime application logic should place the ethos-U scratch buffer in the on-chip memory and the NN in the external memory for optimal performance. +You can see how this coupling between the memory mode and runtime application is done in the +[Ethos-U porting guide](https://github.com/pytorch/executorch/blob/main/examples/arm/ethos-u-porting-guide.md) + +## Using Bundled.io and ETdump + +The arm_executor_runner supports [bundled-io](https://docs.pytorch.org/executorch/0.4/bundled-io.html) and [ETdump](https://docs.pytorch.org/executorch/stable/etdump.html) debugging tools. + +To enable bundled-io, set `EXECUTORCH_BUILD_DEVTOOLS` when building Executorch and `DET_BUNDLE_IO` when building the executor_runner. To enable ETdump, set `EXECUTORCH_BUILD_ARM_ETDUMP` when building Executorch and `DEXECUTORCH_ENABLE_EVENT_TRACER` when building the executor_runner. + +## Issues with memory formats + +Tensors of rank 4 and higher have two differing [memory format](https://pytorch.org/blog/tensor-memory-format-matters/) standards used. +PyTorch defaults to contiguous/ channels first/ NCHW memory formats, compared to TOSA which only supports channels last/NHWC memory format. +To support this, the backend inserts a transpose in the beginning if the incoming memory format is contiguous, and correspondingly a +transpose in the end if the outgoing memory format is contiguous. Note that this means that you may avoid transposing the data unneccessarily if the runtime integration and +full network is converted to use channels last. A word of caution must be given here however - changing memory format has been noted to have side effects such as +unsupported ops being inserted into the graph, and it is currently not widely tested, so the feature must so far be viewed as experimental. diff --git a/docs/source/backend-arm-ethos-u-dedicated_sram.png b/docs/source/backends/arm-ethos-u/backend-arm-ethos-u-dedicated_sram.png similarity index 100% rename from docs/source/backend-arm-ethos-u-dedicated_sram.png rename to docs/source/backends/arm-ethos-u/backend-arm-ethos-u-dedicated_sram.png diff --git a/docs/source/backend-arm-ethos-u-shared_sram.png b/docs/source/backends/arm-ethos-u/backend-arm-ethos-u-shared_sram.png similarity index 100% rename from docs/source/backend-arm-ethos-u-shared_sram.png rename to docs/source/backends/arm-ethos-u/backend-arm-ethos-u-shared_sram.png diff --git a/docs/source/backend-arm-ethos-u-sram_only.png b/docs/source/backends/arm-ethos-u/backend-arm-ethos-u-sram_only.png similarity index 100% rename from docs/source/backend-arm-ethos-u-sram_only.png rename to docs/source/backends/arm-ethos-u/backend-arm-ethos-u-sram_only.png diff --git a/docs/source/backends/arm-ethos-u/tutorials/arm-ethos-u-tutorials.md b/docs/source/backends/arm-ethos-u/tutorials/arm-ethos-u-tutorials.md new file mode 100644 index 00000000000..4b540f2179d --- /dev/null +++ b/docs/source/backends/arm-ethos-u/tutorials/arm-ethos-u-tutorials.md @@ -0,0 +1,10 @@ +# Arm Ethos-U Backend Tutorials + +**→{doc}`ethos-u-getting-started`** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Tutorials + +ethos-u-getting-started diff --git a/docs/source/tutorial-arm-ethos-u.md b/docs/source/backends/arm-ethos-u/tutorials/ethos-u-getting-started.md similarity index 99% rename from docs/source/tutorial-arm-ethos-u.md rename to docs/source/backends/arm-ethos-u/tutorials/ethos-u-getting-started.md index 0e48cd466a0..6c078b0f251 100644 --- a/docs/source/tutorial-arm-ethos-u.md +++ b/docs/source/backends/arm-ethos-u/tutorials/ethos-u-getting-started.md @@ -1,4 +1,4 @@ -# Arm Ethos-U NPU Backend Tutorial +# Getting Started Tutorial ::::{grid} 2 @@ -66,7 +66,6 @@ As a simple check that your environment is set up correctly, run `which FVP_Cors The ExecuTorch Ahead-of-Time (AOT) pipeline takes a PyTorch Model (a `torch.nn.Module`) and produces a `.pte` binary file, which is then consumed by the ExecuTorch Runtime. This [document](getting-started-architecture.md) goes in much more depth about the ExecuTorch software stack for both AoT as well as Runtime. The example below shows how to quantize a model consisting of a single addition, and export it it through the AOT flow using the EthosU backend. For more details, see `examples/arm/ethos_u_minimal_example.ipynb`. - ```python import torch @@ -186,7 +185,6 @@ The block diagram below shows, at the high level, how the various build artifact ![](arm-delegate-runtime-build.svg) - ## Running on Corstone FVP Platforms Finally, use the `backends/arm/scripts/run_fvp.sh` utility script to run the .elf-file on simulated Arm hardware. diff --git a/docs/source/backends/arm-vgf/arm-vgf-overview.md b/docs/source/backends/arm-vgf/arm-vgf-overview.md new file mode 100644 index 00000000000..4cbc6c44305 --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-overview.md @@ -0,0 +1,114 @@ +# Arm VGF Backend + +The Arm® VGF backend is the ExecuTorch solution for lowering PyTorch models to VGF compatible hardware. +It leverages the TOSA operator set and the [ML SDK for Vulkan®](https://github.com/arm/ai-ml-sdk-for-vulkan?tab=readme-ov-file) to produce a .PTE file. +The VGF backend also supports execution from a .PTE file and provides functionality to extract the corresponding VGF file for integration into various applications. + +## Features + +- Wide operator support for delegating large parts of models to the VGF target. +- A quantizer that optimizes quantization for the VGF target. + +## Target Requirements + +The target system must include ML SDK for Vulkan and a Vulkan driver with Vulkan API >= 1.3. + +## Development Requirements + +```{tip} +All requirements can be downloaded using `examples/arm/setup.sh --enable-mlsdk-deps --disable-ethos-u-deps` and added to the path using +`source examples/arm/ethos-u-scratch/setup_path.sh` +``` + +For the AOT flow, compilation of a model to `.pte` format using the VGF backend, the requirements are: +- [TOSA Serialization Library](https://www.mlplatform.org/tosa/software.html) for serializing the Exir IR graph into TOSA IR. +- [ML SDK Model Converter](https://github.com/arm/ai-ml-sdk-model-converter) for converting TOSA flatbuffers to VGF files. + +And for building and running your application using the generic executor_runner: +- [Vulkan API](https://www.vulkan.org) should be set up locally for GPU execution support. +- [ML Emulation Layer for Vulkan](https://github.com/arm/ai-ml-emulation-layer-for-vulkan) for testing on Vulkan API. + +## Using the Arm VGF Backend + +The [VGF Minimal Example](https://github.com/pytorch/executorch/blob/main/examples/arm/vgf_minimal_example.ipynb) demonstrates how to lower a module using the VGF backend. + +The main configuration point for the lowering is the `VgfCompileSpec` consumed by the partitioner and quantizer. +The full user-facing API is documented below. + +```python +class VgfCompileSpec(tosa_spec: executorch.backends.arm.tosa.specification.TosaSpecification | str | None = None, compiler_flags: list[str] | None = None) +``` +Compile spec for VGF compatible targets. + +Args: +- **tosa_spec**: TOSA specification that should be targeted. +- **compiler_flags**: Extra compiler flags for converter_backend. + +```python +def VgfCompileSpec.dump_debug_info(self, debug_mode: executorch.backends.arm.common.arm_compile_spec.ArmCompileSpec.DebugMode | None): +``` +Dump debugging information into the intermediates path. + +Args: +- **debug_mode**: The debug mode to use for dumping debug information. + +```python +def VgfCompileSpec.dump_intermediate_artifacts_to(self, output_path: str | None): +``` +Sets a path for dumping intermediate results during such as tosa and pte. + +Args: +- **output_path**: Path to dump intermediate results to. + +```python +def VgfCompileSpec.get_intermediate_path(self) -> str | None: +``` +Gets the path used for dumping intermediate results such as tosa and pte. + +Returns: + Path where intermediate results are saved. + +```python +def VgfCompileSpec.get_output_format() -> str: +``` +Returns a constant string that is the output format of the class. + + + +### Partitioner API + +See [Partitioner API](arm-vgf-partitioner.md) for more information of the Partitioner API. + +## Quantization + +The VGF quantizer supports [Post Training Quantization (PT2E)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) +and [Quantization-Aware Training (QAT)](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_qat.html). + +For more information on quantization, see [Quantization](arm-vgf-quantization.md). + +## Runtime Integration + +The VGF backend can use the default ExecuTorch runner. The steps required for building and running it are explained in the [VGF Backend Tutorial](tutorials/vgf-getting-started.md). +The example application is recommended to use for testing basic functionality of your lowered models, as well as a starting point for developing runtime integrations for your own targets. + +## Reference + +**→{doc}`/backends/arm-vgf/arm-vgf-partitioner` — Partitioner options.** + +**→{doc}`/backends/arm-vgf/arm-vgf-quantization` — Supported quantization schemes.** + +**→{doc}`/backends/arm-vgf/arm-vgf-troubleshooting` — Debug common issues.** + +**→{doc}`/backends/arm-vgf/tutorials/arm-vgf-tutorials` — Tutorials.** + + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Arm VGF Backend + +arm-vgf-partitioner +arm-vgf-quantization +arm-vgf-troubleshooting +tutorials/arm-vgf-tutorials +``` diff --git a/docs/source/backends/arm-vgf/arm-vgf-partitioner.md b/docs/source/backends/arm-vgf/arm-vgf-partitioner.md new file mode 100644 index 00000000000..e3cbd2f9d22 --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-partitioner.md @@ -0,0 +1,47 @@ +# Partitioner API + +The `VgfPartitioner` controls what parts of a model is delegated to the Arm VGF backend. Below is a reference of the various functions the partitioner provides: + +```python +class VgfPartitioner(compile_spec: executorch.backends.arm.vgf.compile_spec.VgfCompileSpec, additional_checks: Optional[Sequence[torch.fx.passes.operator_support.OperatorSupportBase]] = None) -> None +``` +Partitions subgraphs supported by the Arm Vgf backend. + +Args: +- **compile_spec**: The Vgf compilation specification. +- **additional_checks**: Optional sequence of additional operator support checks. + +```python +def VgfPartitioner.ops_to_not_decompose(self, ep: torch.export.exported_program.ExportedProgram) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.node.Node], bool]]]: +``` +Return operators and a filter that should not be decomposed. + +Provide a base set of ops to preserve as-is and a predicate that keeps +certain activations whole when surrounded by quantize/dequantize ops in +a quantized graph. This helps downstream TOSA lowering and delegation. + +Args: +- **ep (ExportedProgram)**: Program used to infer target-specific policy. + +Returns: +- **Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]**: + A list of op overloads to keep intact, and an optional filter + function that returns True when an op should not be decomposed. + +```python +def VgfPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult: +``` +Partition the program and tag TOSA-compatible subgraphs. + +Run the FX capability-based partitioner to propose subgraphs, then +refine tags by removing boundary-only quantize/dequantize nodes and by +rejecting partitions that would lower to no-ops. Emit a detailed report +of rejected nodes and their reasons. + +Args: +- **exported_program (ExportedProgram)**: Program to analyze and + partition. + +Returns: +- **PartitionResult**: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. diff --git a/docs/source/backends/arm-vgf/arm-vgf-quantization.md b/docs/source/backends/arm-vgf/arm-vgf-quantization.md new file mode 100644 index 00000000000..68f77249885 --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-quantization.md @@ -0,0 +1,73 @@ +# Quantization + +The Arm VGF delegate can be used to execute quantized models. To quantize a model so that is supported by this delegate, the `VgfQuantizer` should be used. + +Currently the symmetric `int8` config defined by `executorch.backends.arm.quantizer.arm_quantizer.get_symmetric_quantization_config` is the main config available to use with the VGF quantizer. + +### Supported Quantization Schemes + +The quantization schemes supported by the VGF Backend are: +- 8-bit symmetric weights with 8-bit asymmetric activations (via the PT2E quantization flow). + - Supports both static and dynamic activations + - Supports per-channel and per-tensor schemes + +Weight-only quantization is not currently supported on the VGF backend. + +### Quantization API + +```python +class VgfQuantizer(compile_spec: 'VgfCompileSpec') -> 'None' +``` +Quantizer supported by the Arm Vgf backend. + +Args: +- **compile_spec**: A VgfCompileSpec instance. + +```python +def VgfQuantizer.set_global(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for submodules that are not already annotated by name or type filters. + +Args: +- **quantization_config**: The QuantizationConfig to set as global configuration. + +```python +def VgfQuantizer.set_io(self, quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for input and output nodes. + +Args: +- **quantization_config**: The QuantizationConfig to set for input and output nodes. + +```python +def VgfQuantizer.set_module_name(self, module_name: 'str', quantization_config: 'Optional[QuantizationConfig]') -> 'TOSAQuantizer': +``` +Set quantization_config for a submodule with name: `module_name`, for example: +quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator +patterns in the submodule with this module name with the given `quantization_config` + +Args: +- **module_name**: The name of the submodule to set the quantization config for. +- **quantization_config**: The QuantizationConfig to set for the submodule. + +```python +def VgfQuantizer.set_module_type(self, module_type: 'Callable', quantization_config: 'QuantizationConfig') -> 'TOSAQuantizer': +``` +Set quantization_config for a submodule with type: `module_type`, for example: +quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator +patterns in the submodule with this module type with the given `quantization_config`. + +Args: +- **module_type**: The type of the submodule to set the quantization config for. +- **quantization_config**: The QuantizationConfig to set for the submodule. + +```python +def VgfQuantizer.transform_for_annotation(self, model: 'GraphModule') -> 'GraphModule': +``` +An initial pass for transforming the graph to prepare it for annotation. +Currently transforms scalar values to tensor attributes. + +Args: +- **model**: The model to transform. +Returns: + The transformed model. diff --git a/docs/source/backends/arm-vgf/arm-vgf-troubleshooting.md b/docs/source/backends/arm-vgf/arm-vgf-troubleshooting.md new file mode 100644 index 00000000000..6100bc94b0c --- /dev/null +++ b/docs/source/backends/arm-vgf/arm-vgf-troubleshooting.md @@ -0,0 +1,7 @@ +# Arm VGF Troubleshooting + +This page describes common issues that you may encounter when using the Arm VGF backend and how to debug and resolve them. + +## How do you visualize VGF files + +The [VGF Adapter for Model Explorer](https://github.com/arm/vgf-adapter-model-explorer) enables visualization of VGF files and can be useful for debugging. diff --git a/docs/source/backends/arm-vgf/tutorials/arm-vgf-tutorials.md b/docs/source/backends/arm-vgf/tutorials/arm-vgf-tutorials.md new file mode 100644 index 00000000000..ceb4304a814 --- /dev/null +++ b/docs/source/backends/arm-vgf/tutorials/arm-vgf-tutorials.md @@ -0,0 +1,10 @@ +# Arm VGF Backend Tutorials + +**→{doc}`vgf-getting-started`** + +```{toctree} +:maxdepth: 2 +:hidden: +:caption: Tutorials + +vgf-getting-started diff --git a/docs/source/tutorial-arm-vgf.md b/docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md similarity index 86% rename from docs/source/tutorial-arm-vgf.md rename to docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md index 0e34e4be4b6..f6015eaff74 100644 --- a/docs/source/tutorial-arm-vgf.md +++ b/docs/source/backends/arm-vgf/tutorials/vgf-getting-started.md @@ -1,4 +1,4 @@ -# Arm VGF Backend Tutorial +# Getting Started Tutorial ::::{grid} 2 @@ -29,7 +29,7 @@ If you are already familiar with this delegate, you may want to jump directly to * [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/examples/arm/aot_arm_compiler.py) ``` -This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on VGF targets. The tutorial is based on `vgf_minimal_example.ipyb`, provided in Arm®'s example folder. +This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on VGF targets. The tutorial is based on `vgf_minimal_example.ipyb`, provided in Arm's example folder. ## Prerequisites @@ -43,16 +43,14 @@ To enable development without a specific development board, we will be using the First, you will need to install ExecuTorch. Please follow the recommended tutorials if you haven't already, to set up a working ExecuTorch development environment. For the VGF backend it's recommended you [install from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html), or from a [nightly](https://download.pytorch.org/whl/nightly/executorch/). -Additionally, you need to install a number of SDK dependencies for generating VGF files. For glslc, prefer installing it via your package manager. If this is not possible, and for other dependencies, there are scripts to automate installation available in the main [ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm/). glscl will then be installed via the Vulkan SDK. - -To install VGF dependencies, run +In addition to this, you need to install a number of SDK dependencies for generating VGF files. Scripts to automate this are available in the main [ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm/). To install VGF dependencies, run ```bash ./examples/arm/setup.sh --i-agree-to-the-contained-eula --disable-ethos-u-deps --enable-mlsdk-deps ``` This will install: - [TOSA Serialization Library](https://www.mlplatform.org/tosa/software.html) for serializing the Exir IR graph into TOSA IR. - [ML SDK Model Converter](https://github.com/arm/ai-ml-sdk-model-converter) for converting TOSA flatbuffers to VGF files. -- [Vulkan API (If needed)](https://www.vulkan.org) Should be set up locally for GPU execution support. +- [Vulkan API](https://www.vulkan.org) should be set up locally for GPU execution support. - [ML Emulation Layer for Vulkan](https://github.com/arm/ai-ml-emulation-layer-for-vulkan) for testing on Vulkan API. @@ -67,13 +65,13 @@ As a simple check that your environment is set up correctly, run ```bash which model-converter ``` -Make sure the executable is located where you expect, in the `examples/arm` tree. +Make sure the executable is located where you expect, in the `examples/arm` tree. ## Build ### Ahead-of-Time (AOT) components -The ExecuTorch Ahead-of-Time (AOT) pipeline takes a PyTorch Model (a `torch.nn.Module`) and produces a `.pte` binary file, which is then typically consumed by the ExecuTorch Runtime. This [document](getting-started-architecture.md) goes in much more depth about the ExecuTorch software stack for both AoT as well as Runtime. +The ExecuTorch Ahead-of-Time (AOT) pipeline takes a PyTorch Model (a `torch.nn.Module`) and produces a `.pte` binary file, which is then typically consumed by the ExecuTorch Runtime. This [document](https://github.com/pytorch/executorch/blob/main/docs/source/getting-started-architecture.md) goes in much more depth about the ExecuTorch software stack for both AoT as well as Runtime. The example below shows how to quantize a model consisting of a single addition, and export it it through the AOT flow using the VGF backend. For more details, se `examples/arm/vgf_minimal_example.ipynb`. @@ -88,15 +86,15 @@ example_inputs = (torch.ones(1,1,1,1),torch.ones(1,1,1,1)) model = Add() model = model.eval() -exported_program = torch.export.export_for_training(model, example_inputs) +exported_program = torch.export.export(model, example_inputs) graph_module = exported_program.graph_module -from executorch.backends.arm.vgf import VgfCompileSpec from executorch.backends.arm.quantizer import ( VgfQuantizer, get_symmetric_quantization_config, ) +from executorch.backends.arm.vgf import VgfCompileSpec from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e # Create a compilation spec describing the target for configuring the quantizer @@ -155,7 +153,7 @@ assert os.path.exists(pte_path), "Build failed; no .pte-file found" ```{tip} For a quick start, you can use the script `examples/arm/aot_arm_compiler.py` to produce the pte. To produce a pte file equivalent to the one above, run -`python -m examples.arm.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf` +`python -m examples.arm.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf` ``` ### Runtime: @@ -171,7 +169,6 @@ cmake \ -DCMAKE_BUILD_TYPE=Debug \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ @@ -208,14 +205,9 @@ In this tutorial you have learned how to use ExecuTorch to export a PyTorch mode ## FAQs -*glslc is not found when configuring the executor runner*. - -The Vulkan sdk is likely not in your path, check whether setup_path.sh contains something like +Issue: glslc is not found when configuring the executor runner. +Solution: The Vulkan sdk is likely not in your path, check whether setup_path.sh contains something like `export PATH=$(pwd)/examples/arm/ethos-u-scratch/vulkan_sdk/1.4.321.1/x86_64/bin:$PATH`. If not, add it and source the file. If you encountered any bugs or issues following this tutorial please file a bug/issue here on [Github](https://github.com/pytorch/executorch/issues/new). - -``` -Arm is a registered trademark of Arm Limited (or its subsidiaries or affiliates). -``` \ No newline at end of file diff --git a/docs/source/compiler-delegate-and-partitioner.md b/docs/source/compiler-delegate-and-partitioner.md index b057f3afa2e..c0449e7366b 100644 --- a/docs/source/compiler-delegate-and-partitioner.md +++ b/docs/source/compiler-delegate-and-partitioner.md @@ -131,7 +131,7 @@ static auto success_with_compiler = register_backend(backend); Providing consistent debugging experience, be it for runtime failures or performance profiling, is important. ExecuTorch employs native Developer Tools for this purpose, which enables correlating program instructions to original PyTorch code, via debug handles. You can read more about it [here](etrecord.rst). -Delegated program or subgraphs are opaque to ExecuTorch runtime and appear as a special `call_delegate` instruction, which asks corresponding backend to handle the execution of the subgraph or program. Due to the opaque nature of backend delgates, native Developer Tools does not have visibility into delegated program. Thus the debugging, functional or performance, experiences of delegated execution suffers significantly as compared to it's non-delegated counterpart. +Delegated program or subgraphs are opaque to ExecuTorch runtime and appear as a special `call_delegate` instruction, which asks corresponding backend to handle the execution of the subgraph or program. Due to the opaque nature of backend delegates, native Developer Tools does not have visibility into delegated program. Thus the debugging, functional or performance, experiences of delegated execution suffers significantly as compared to it's non-delegated counterpart. In order to provide consistent debugging experience to users, regardless of the use of delegation for a model, Developer Tools provide an interface to correlate delegated (sub)graph to original (sub)graph. The Developer Tools do so via debug handles map which allows delegates to generate internal handles that can be associated with the original (sub)graph consumed by the delegate. Then at runtime, backend developer can report error or profiling information using the internal handle, which will be mapped to original (sub)graph using the debug handle map. For more information, please refer to [Delegate Debugging](delegate-debugging.md). diff --git a/docs/source/conf.py b/docs/source/conf.py index 78268c8d053..f69fc243255 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -76,25 +76,40 @@ html_favicon = "_static/img/executorch-chip-logo.svg" -# Get ET_VERSION_DOCS during the build. -et_version_docs = os.environ.get("ET_VERSION_DOCS", None) -print(f"et_version_docs: {et_version_docs}") - -# The code below will cut version displayed in the dropdown like this: -# By default, set to "main". -# If it's a tag like refs/tags/v1.2.3-rc4 or refs/tags/v1.2.3, then -# cut to 1.2 -# the version varible is used in layout.html: https://github.com/pytorch/executorch/blob/main/docs/source/_templates/layout.html#L29 -version = release = "main" -if et_version_docs: - if et_version_docs.startswith("refs/tags/v"): - version = ".".join( - et_version_docs.split("/")[-1].split("-")[0].lstrip("v").split(".")[:2] - ) - elif et_version_docs.startswith("refs/heads/release/"): - version = et_version_docs.split("/")[-1] -print(f"Version: {version}") -html_title = " ".join((project, version, "documentation")) +# Import executorch version +# Adopted from PyTorch docs pattern +from executorch import version as et_version # type: ignore[attr-defined] + +executorch_version = str(et_version.__version__) + +# Check if this is a release build from environment variable +# The workflow sets RELEASE=true for tagged releases, RELEASE=false otherwise +# We need to properly parse the string as a boolean (any non-empty string is truthy in Python) +RELEASE = os.environ.get("RELEASE", "false").lower() == "true" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = "main" +# The full version, including alpha/beta/rc tags. +release = "main" + +# Customized html_title here. +# Default is " ".join(project, release, "documentation") if not set +if RELEASE: + # Turn 0.8.0a0+a90e907 into 0.8 + # Note: the release candidates should no longer have the aHASH suffix, but in any + # case we wish to leave only major.minor, even for rc builds. + version = ".".join(executorch_version.split("+")[0].split(".")[:2]) + html_title = " ".join((project, version, "documentation")) + release = version + +switcher_version = "main" if not RELEASE else version + +print(f"executorch_version: {executorch_version}") +print(f"Version: {version}, RELEASE: {RELEASE}") html_baseurl = "https://docs.pytorch.org/executorch/" # needed for sphinx-sitemap sitemap_locales = [None] @@ -176,8 +191,6 @@ # documentation. # -switcher_version = version - html_theme_options = { "logo": { "image_light": "_static/img/et-logo.png", @@ -242,6 +255,7 @@ "display_version": True, } + # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". diff --git a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md index ae1b4f15c99..1168c4c04a3 100644 --- a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md +++ b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md @@ -26,7 +26,7 @@ Deploying large language models like Llama 3 on-device presents the following ch To address these, we apply the following optimizations: -1. Quantization: Use `QuantDtype.use_16a4w_block` for post-training quantization to reduce model size and memory usage. +1. Quantization: Apply the `quant_recipe` when setting the quantization config to reduce model size and memory usage. 2. Mixed Precision Quantization: compresses KV cache tensors to 8-bit and applies `QuantDtype.use_16a8w` to the LM head. @@ -48,9 +48,6 @@ class Llama3_2_3B_Instruct(LLMModelConfig): instruct_model = False num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 # Group size used in block quantization for weight quantization. Will only be used when ptq = 16a4w_block masked_softmax = False # SeqMSE Quantization: optimizes the parameter encodings of each layer of a model individually to minimize the difference between the layer’s original and quantized outputs. (Implementation details: ./backends/qualcomm/_passes/seq_mse.py) In this configuration, we set `seq_mse_candidates` = 0, which means SeqMSE quantization is not applied. @@ -58,10 +55,8 @@ class Llama3_2_3B_Instruct(LLMModelConfig): r1 = False r2 = False r3 = False - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - ) + # quant recipe + quant_recipe = Llama3_3BQuantRecipe ``` diff --git a/docs/source/new-contributor-guide.md b/docs/source/new-contributor-guide.md index d2074a3379f..ec5e67afc87 100644 --- a/docs/source/new-contributor-guide.md +++ b/docs/source/new-contributor-guide.md @@ -103,13 +103,6 @@ Before you can start writing any code, you need to get a copy of ExecuTorch code * The `origin` entries show your forked GitHub repository. They tell you that when you run `git pull` or `git push`, your changes will go from/to your GitHub fork. * The `upstream` entries show the main ExecuTorch repository. If you want to sync the latest changes from there, you can run `git fetch upstream`. - - Let's sync from both your fork _and_ the main ExecuTorch branch, getting the latest changes from each of them. To do this, run: - - ```bash - git fetch --all --prune - ``` - 4. If you just cloned your fork, your GitHub repository will tell you your branch is up-to-date: ![](_static/img/new-contributor-guide/synced_fork.png) diff --git a/docs/source/using-executorch-building-from-source.md b/docs/source/using-executorch-building-from-source.md index b060bab2746..09980a94cdc 100644 --- a/docs/source/using-executorch-building-from-source.md +++ b/docs/source/using-executorch-building-from-source.md @@ -41,7 +41,7 @@ toolchains, down to C++17. See [Runtime Overview](runtime-overview.md) for portability details. ## Environment Setup - Clone the ExecuTorch repository from GitHub and create a conda environment. Venv can be used in place on conda. + Clone the ExecuTorch repository from GitHub and create a conda environment. Venv can be used in place of conda. ```bash git clone -b viable/strict https://github.com/pytorch/executorch.git cd executorch diff --git a/examples/apple/coreml/scripts/extract_coreml_models.py b/examples/apple/coreml/scripts/extract_coreml_models.py index b3778a22625..593a270186b 100644 --- a/examples/apple/coreml/scripts/extract_coreml_models.py +++ b/examples/apple/coreml/scripts/extract_coreml_models.py @@ -21,7 +21,7 @@ def extract_coreml_models(pte_data: bytes): - program = deserialize_pte_binary(pte_data) + program = deserialize_pte_binary(pte_data).program delegates: List[BackendDelegate] = sum( [execution_plan.delegates for execution_plan in program.execution_plan], [] ) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 9c35e23d5dd..813101b77b7 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -19,7 +19,10 @@ from examples.devtools.scripts.export_bundled_program import save_bundled_program from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.ethosu import EthosUCompileSpec -from executorch.backends.arm.quantizer import get_symmetric_quantization_config +from executorch.backends.arm.quantizer import ( + get_symmetric_a16w8_quantization_config, + get_symmetric_quantization_config, +) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.util._factory import create_partitioner, create_quantizer @@ -32,8 +35,8 @@ from executorch.backends.arm.vgf import VgfCompileSpec # To use Cortex-M backend -from executorch.backends.cortex_m.passes.quantized_linear_fusion_pass import ( - QuantizedLinearFusionPass, +from executorch.backends.cortex_m.passes.convert_to_cortex_m_pass import ( + ConvertToCortexMPass, ) from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( @@ -228,6 +231,7 @@ def quantize( example_inputs: Tuple[torch.Tensor], evaluator_name: str | None, evaluator_config: Dict[str, Any] | None, + is_int16x8: bool = False, ) -> GraphModule: """This is the official recommended flow for quantization in pytorch 2.0 export. @@ -238,7 +242,18 @@ def quantize( quantizer = create_quantizer(compile_specs) - operator_config = get_symmetric_quantization_config() + if is_int16x8: + if compile_specs.tosa_spec.support_extension("int16"): + operator_config = get_symmetric_a16w8_quantization_config( + is_per_channel=True + ) + else: + raise ValueError( + f"Context TOSA spec {compile_specs.tosa_spec} doesn't support int16" + ) + else: + operator_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_global(operator_config) m = prepare_pt2e(model, quantizer) @@ -356,6 +371,7 @@ def forward(self, x): "vgf", "TOSA-1.0+INT", "TOSA-1.0+FP", + "TOSA-1.0+INT+int16", ] @@ -681,20 +697,23 @@ def quantize_model( example_inputs: Tuple[torch.Tensor], compile_spec, ) -> Tuple[GraphModule, ExportedProgram]: - model_int8 = quantize( + + is_int16x8 = True if args.target == "TOSA-1.0+INT+int16" else False + model_quant = quantize( model, args.model_name, compile_spec, example_inputs, args.evaluate, args.evaluate_config, + is_int16x8, ) # Wrap quantized model back into an exported_program exported_program = torch.export.export( - model_int8, example_inputs, strict=args.strict_export + model_quant, example_inputs, strict=args.strict_export ) - return model_int8, exported_program + return model_quant, exported_program def to_edge_TOSA_delegate( @@ -715,9 +734,9 @@ def to_edge_TOSA_delegate( args.enable_debug_mode, ) - model_int8 = None + model_quant = None if args.quantize: - model_int8, exported_program = quantize_model( + model_quant, exported_program = quantize_model( args, model, example_inputs, compile_spec ) @@ -731,7 +750,7 @@ def to_edge_TOSA_delegate( ), ) - return model_int8, edge + return model_quant, edge def to_edge_no_delegate( @@ -740,7 +759,7 @@ def to_edge_no_delegate( model: GraphModule, example_inputs: Tuple[torch.Tensor], ): - model_int8 = None + model_quant = None if args.quantize: # As we can target multiple output encodings, one must # be specified. @@ -756,7 +775,7 @@ def to_edge_no_delegate( model, exported_program = quantize_model( args, model, example_inputs, compile_spec ) - model_int8 = model + model_quant = model edge = to_edge_transform_and_lower( exported_program, @@ -765,7 +784,7 @@ def to_edge_no_delegate( ), ) - return model_int8, edge + return model_quant, edge def transform_for_cortex_m_backend(edge_program_manager, args): @@ -776,7 +795,7 @@ def transform_for_cortex_m_backend(edge_program_manager, args): # Instantiate the mandatory ReplaceQuantNodesPass passes = [ReplaceQuantNodesPass] if args.enable_qdq_fusion_pass: - passes += [QuantizedLinearFusionPass, QuantizedOpFusionPass] + passes += [ConvertToCortexMPass, QuantizedOpFusionPass] current_edge = edge_program_manager for pass_cls in passes: transform_pass = ( @@ -818,13 +837,13 @@ def transform_for_cortex_m_backend(edge_program_manager, args): ) # Quantize if required - model_int8 = None + model_quant = None if args.delegate: - model_int8, edge = to_edge_TOSA_delegate( + model_quant, edge = to_edge_TOSA_delegate( exported_program, args, model, example_inputs ) else: - model_int8, edge = to_edge_no_delegate( + model_quant, edge = to_edge_no_delegate( exported_program, args, model, example_inputs ) @@ -884,7 +903,7 @@ def transform_for_cortex_m_backend(edge_program_manager, args): if args.bundleio: # Realize the quantization impact on numerics when generating reference output - reference_model = original_model if not model_int8 else model_int8 + reference_model = original_model if not model_quant else model_quant save_bpte_program(exec_prog, reference_model, output_file_name) print(f"Bundle PTE file saved as {output_file_name}") else: @@ -895,8 +914,9 @@ def transform_for_cortex_m_backend(edge_program_manager, args): evaluate_model( args.model_name, args.intermediates, + args.target, model_fp32, - model_int8, + model_quant, example_inputs, args.evaluate, args.evaluate_config, diff --git a/examples/arm/executor_runner/CMakeLists.txt b/examples/arm/executor_runner/CMakeLists.txt index d5038a1a6b8..bb5d3f59e1e 100644 --- a/examples/arm/executor_runner/CMakeLists.txt +++ b/examples/arm/executor_runner/CMakeLists.txt @@ -326,12 +326,29 @@ endif() # Need whole-archive to ensure C++ ctor's are called - this may be wasteful for # bin size as we link in a number of other symbols -target_link_libraries(arm_executor_runner ${arm_executor_runner_link}) +target_link_libraries(arm_executor_runner PUBLIC ${arm_executor_runner_link}) target_link_options( arm_executor_runner PUBLIC LINKER:-Map=arm_executor_runner.map ) +# Sanitizers +if(CMAKE_BUILD_TYPE MATCHES "UndefinedSanitizer") + set(_et_runner_ubsan_flag -fsanitize=undefined) + target_compile_options(arm_executor_runner PRIVATE ${_et_runner_ubsan_flag}) + target_link_options(arm_executor_runner PRIVATE ${_et_runner_ubsan_flag}) + if(NOT TARGET executorch_ubsan) + add_subdirectory( + ${ET_DIR_PATH}/examples/arm/ubsan + ${CMAKE_CURRENT_BINARY_DIR}/ubsan_runtime + ) + endif() + target_link_directories( + arm_executor_runner PRIVATE $ + ) + target_link_libraries(arm_executor_runner PRIVATE executorch_ubsan) +endif() + # ET headers and generated headers includes target_include_directories( arm_executor_runner diff --git a/examples/arm/executor_runner/arm_executor_runner.cpp b/examples/arm/executor_runner/arm_executor_runner.cpp index 696817450b5..928b0fc2a55 100644 --- a/examples/arm/executor_runner/arm_executor_runner.cpp +++ b/examples/arm/executor_runner/arm_executor_runner.cpp @@ -410,8 +410,7 @@ Error prepare_input_tensors( "Wrong number of inputs allocated compared to method"); #endif - EValue* input_evalues = - static_cast(allocator.allocate(num_inputs * sizeof(EValue*))); + EValue* input_evalues = allocator.allocateList(num_inputs); ET_CHECK_OR_RETURN_ERROR( input_evalues != nullptr, MemoryAllocationFailed, @@ -471,6 +470,10 @@ Error prepare_input_tensors( tensor.mutable_data_ptr() + tensor.numel(), 1); break; + default: + ET_LOG(Error, "Unhandled ScalarType"); + err = Error::InvalidArgument; + break; } } else { printf("Input[%d]: Not Tensor\n", i); diff --git a/examples/arm/executor_runner/arm_perf_monitor.cpp b/examples/arm/executor_runner/arm_perf_monitor.cpp index 58a47105743..35fd114f777 100644 --- a/examples/arm/executor_runner/arm_perf_monitor.cpp +++ b/examples/arm/executor_runner/arm_perf_monitor.cpp @@ -19,7 +19,7 @@ namespace { #if defined(ETHOSU55) || defined(ETHOSU65) const uint32_t ethosu_pmuCountersUsed = 4; #elif defined(ETHOSU85) -const uint32_t ethosu_pmuCountersUsed = 5; +const uint32_t ethosu_pmuCountersUsed = 7; #else #error No NPU target defined #endif @@ -65,11 +65,14 @@ void ethosu_inference_begin(struct ethosu_driver* drv, void*) { ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED); ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN); ETHOSU_PMU_Set_EVTYPER(drv, 4, ETHOSU_PMU_NPU_IDLE); - // Enable the 5 counters + ETHOSU_PMU_Set_EVTYPER(drv, 5, ETHOSU_PMU_MAC_ACTIVE); + ETHOSU_PMU_Set_EVTYPER(drv, 6, ETHOSU_PMU_WD_ACTIVE); + // Enable the 7 counters ETHOSU_PMU_CNTR_Enable( drv, ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CNT2_Msk | ETHOSU_PMU_CNT3_Msk | - ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk); + ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk | ETHOSU_PMU_CNT6_Msk | + ETHOSU_PMU_CNT7_Msk); #else #error No NPU target defined #endif @@ -214,7 +217,7 @@ void StopMeasurements(int num_inferences) { #elif defined(ETHOSU85) ET_LOG( Info, - "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE]"); + "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE, ETHOSU_PMU_MAC_ACTIVE, ETHOSU_PMU_WD_ACTIVE]"); #else #error No NPU target defined #endif diff --git a/examples/arm/executor_runner/pte_to_header.py b/examples/arm/executor_runner/pte_to_header.py index 1b5fad05a12..65213bc729e 100644 --- a/examples/arm/executor_runner/pte_to_header.py +++ b/examples/arm/executor_runner/pte_to_header.py @@ -59,7 +59,7 @@ def input_file_path(path): if __name__ == "__main__": args = parser.parse_args() outfile = os.path.join(args.outdir, args.outfile) - attr = f'__attribute__((section("{args.section}"), aligned(16))) char ' + attr = f'__attribute__((section("{args.section}"), aligned(16))) unsigned char ' with open(args.pte, "rb") as fr, open(outfile, "w") as fw: data = fr.read() diff --git a/examples/arm/pruning_minimal_example.ipynb b/examples/arm/pruning_minimal_example.ipynb new file mode 100644 index 00000000000..78bb3f06b5b --- /dev/null +++ b/examples/arm/pruning_minimal_example.ipynb @@ -0,0 +1,566 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c0156802", + "metadata": {}, + "source": [ + "# Copyright 2025 Arm Limited and/or its affiliates.\n", + "#\n", + "# This source code is licensed under the BSD-style license found in the\n", + "# LICENSE file in the root directory of this source tree." + ] + }, + { + "cell_type": "markdown", + "id": "26b849fd", + "metadata": {}, + "source": [ + "# Introduction\n", + "Model conditioning techniques like pruning modify the weights of a Machine Learning model and in some cases allow significant speed-up of the inference execution, reduction of the memory footprint and reduction in the overall power consumption of the system. Assuming you can optimise your workload without loss in accuracy and you target an Arm® Ethos™ NPU or a GPU with a Neural Engine, you should consider pruning the neural network before compiling it in the to_edge_transform_and_lower stage." + ] + }, + { + "cell_type": "markdown", + "id": "9a7d6d97", + "metadata": {}, + "source": [ + "# Why apply model conditioning?\n", + "The Ethos-U hardware has a dedicated weight decoder to process the model weights. At the same time, the compiler arranges the weights into blocks and the blocks are then fed to the hardware weight decoder. As part of the block arrangement process, the compiler compresses sequences of zero weights and clusters of weights. To avoid any doubt, the compression by the compiler is lossless - to the same input tensor, irrespective of whether compression was applied or not, the output tensor from execution on the NPU will be the same. If the model you provide in the to_edge_transform_and_lower stage is optimised to have sequences of zero weights and/or clusters of the same weights, the compiler will be able to compress these weights very efficiently. The good compression would result in lower number of memory accesses by the NPU at runtime, which would mean that the MAC engines are not waiting on memory accesses resulting in better overall performance. In other words, if you have a memory bound model, you should consider pruning and clustering your neural network before lowering it in the to_edge_transform_and_lower stage.\n", + "\n", + "The Ethos-U85 hardware also has hardware support for 2:4 sparse weights - if you have 2:4 sparse weights, the MAC array will skip multiplications where the result will be 0. The 2:4 sparsity allow power savings for all configurations and provides a speed-up on compute-bound neural networks.\n", + "\n", + "Before we begin, make sure you are running the Jupyter notebook from the correct python virtual environment variable." + ] + }, + { + "cell_type": "markdown", + "id": "d6532247", + "metadata": {}, + "source": [ + "# Prerequisites\n", + "Let's import python the packages you will need to run through the jupyter notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8a191d7", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torchvision import datasets, transforms\n", + "from torch import nn\n", + "import torch.nn.utils.prune as prune\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader, Subset\n", + "import random\n", + "\n", + "from executorch.backends.arm.ethosu import EthosUPartitioner\n", + "from executorch.exir import (\n", + " EdgeCompileConfig,\n", + " ExecutorchBackendConfig,\n", + " to_edge_transform_and_lower,\n", + ")\n", + "from executorch.backends.arm.ethosu import EthosUCompileSpec\n", + "from executorch.backends.arm.quantizer import (\n", + " EthosUQuantizer,\n", + " get_symmetric_quantization_config,\n", + ")\n", + "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", + "from executorch.extension.export_util.utils import save_pte_program" + ] + }, + { + "cell_type": "markdown", + "id": "6af794bc", + "metadata": {}, + "source": [ + "# Model conditioning with PyTorch and deployment with ExecuTorch \n", + "We'll define a simple model with 3 back-to-back Linear layers. We will execute the model on the Ethos-U85 NPU, then we will prune the model and execute the pruned variant on the Ethos-U85 and compare the performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e37c2ce", + "metadata": {}, + "outputs": [], + "source": [ + "LR = 1e-3\n", + "NUM_EPOCHS = 1\n", + "BATCH_SIZE = 128\n", + "\n", + "# Data\n", + "transform = transforms.Compose([transforms.ToTensor()])\n", + "train_ds = datasets.MNIST(\"./data\", train=True, download=True, transform=transform)\n", + "test_ds = datasets.MNIST(\"./data\", train=False, transform=transform)\n", + "\n", + "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n", + "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)\n", + "\n", + "class Simple_NN(nn.Module): \n", + " def __init__(self):\n", + " super().__init__()\n", + " self.flatten = nn.Flatten()\n", + " self.fc1 = nn.Linear(28 * 28, 512)\n", + " self.fc2 = nn.Linear(512, 256)\n", + " self.fc3 = nn.Linear(256, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.flatten(x)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.relu(self.fc2(x))\n", + " x = self.fc3(x)\n", + " return x\n", + " \n", + " def prunable_parameters(self):\n", + " return (\n", + " (self.fc1, \"weight\"),\n", + " (self.fc2, \"weight\"),\n", + " (self.fc3, \"weight\"),\n", + " )\n", + "\n", + " def prune(self, pruning_method: prune.BasePruningMethod, amount: float = 0.1):\n", + " # reference https://pytorch.org/tutorials/intermediate/pruning_tutorial.html\n", + "\n", + " # produces a mask that is multiplied with the parameter\n", + " prune.global_unstructured(\n", + " self.prunable_parameters(),\n", + " pruning_method=pruning_method,\n", + " amount=amount,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "6db1e58d", + "metadata": {}, + "source": [ + "We define a simple model with 3 back-to-back linear layers. Linear is highly memory bound operation because every weight is read once only from the external memory. It is impossible to buffer the weights in memory(you usually have more weights in the external memory than space in the SARM) and reuse them for the computation. In comparison, in a convolution you usually have small filter sizes(e.g. 3x3 filter) which means you can buffer all the convolution weights in memory and reuse them for the computation. If your model or module within the model is composed entirely of Linear layers, the workload will be memory bound and pruning is likely to provide good speed-up.\n", + "\n", + "Next, let's define a simple function to train the network and a function to evaluate the accuracy of the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "477312ae", + "metadata": {}, + "outputs": [], + "source": [ + "# Training loop\n", + "def train(model):\n", + " # The model is simple enough that we can train it on CPU\n", + " device = \"cpu\"\n", + " for epoch in range(NUM_EPOCHS):\n", + " # ---- Training ----\n", + " model.train()\n", + " opt = torch.optim.Adam(model.parameters(), lr=LR)\n", + " criterion = torch.nn.CrossEntropyLoss()\n", + " for step, (inp, out_real) in enumerate(train_loader):\n", + " inp, out_real = inp.to(device), out_real.to(device)\n", + " opt.zero_grad()\n", + " out_pred = model(inp)\n", + " loss = criterion(out_pred, out_real)\n", + " #print(f\"Loss: {loss.item():.4f}\")\n", + " loss.backward()\n", + " opt.step()\n", + "\n", + "def evaluate(model):\n", + " # ---- Evaluation ----\n", + " correct, total = 0, 0\n", + " with torch.no_grad():\n", + " for inp, out_real in test_loader:\n", + " out_pred = model(inp)\n", + " preds = out_pred.argmax(1)\n", + " correct += (preds == out_real).sum().item()\n", + " total += out_real.size(0)\n", + "\n", + " acc = 100 * correct / total\n", + " print(f\"Top 1 accuracy = {acc:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "id": "a4750eaf", + "metadata": {}, + "source": [ + "Let's instantiate the model and train it. In order to get reproducible results, we will fix the seed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc68a7d9", + "metadata": {}, + "outputs": [], + "source": [ + "SEED = 123\n", + "torch.manual_seed(SEED)\n", + "model = Simple_NN()\n", + "train(model)\n", + "print(\"Evaluate FP32 model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "9837d9ba", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy for the FP32 model.\n", + "\n", + "Next, we would like to apply post-training quantization with ExecuTorch and evaluate the accuracy of the quantized model. It is important to calibrate the quantized model on a few real samples from the MNIST dataset to get good quantization parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "855c542f", + "metadata": {}, + "outputs": [], + "source": [ + "# MNIST images are 28x28 in greyscale, hence the shape is 1x1x28x28\n", + "example_inputs = (torch.randn(1,1,28,28),)\n", + "exported_program = torch.export.export(model, example_inputs)\n", + "graph_module = exported_program.module(check_guards=False)\n", + "\n", + "# Create a compilation spec describing the target for configuring the quantizer\n", + "compile_spec = EthosUCompileSpec(\n", + " target=\"ethos-u85-128\",\n", + " system_config=\"Ethos_U85_SYS_Flash_High\",\n", + " memory_mode=\"Shared_Sram\",\n", + " extra_flags=[\"--output-format=raw\", \"--debug-force-regor --verbose-weights\"]\n", + " )\n", + "\n", + "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", + "quantizer = EthosUQuantizer(compile_spec)\n", + "operator_config = get_symmetric_quantization_config()\n", + "quantizer.set_global(operator_config)\n", + "\n", + "# Post training quantization, need a few example images to obtain good quantization parameters\n", + "subset_indices = random.sample(range(len(train_ds)), 50)\n", + "calibration_set = Subset(train_ds, subset_indices)\n", + "calibration_loader = DataLoader(calibration_set, shuffle=False)\n", + "\n", + "quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n", + "for batch_images,label in calibration_loader:\n", + " quantized_graph_module(*batch_images) # Calibrate the graph module with the example input\n", + "quantized_graph_module = convert_pt2e(quantized_graph_module)" + ] + }, + { + "cell_type": "markdown", + "id": "996faefd", + "metadata": {}, + "source": [ + "Next, let us evaluate the accuracy of the quantized model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63da2b30", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Accuracy of the quantized model\")\n", + "evaluate(quantized_graph_module)" + ] + }, + { + "cell_type": "markdown", + "id": "2ff3462c", + "metadata": {}, + "source": [ + "We maintain the 96% top1 accuracy for the quantized model. Next, let's compile the model for the Ethos-U backend. We will define a function `generate_pte` that calls `to_edge_transform_and_lower` and saves the pte file on device." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa8259f4", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_pte(quantized_exported_program,compile_spec,name):\n", + " # Create partitioner from compile spec\n", + " partitioner = EthosUPartitioner(compile_spec)\n", + "\n", + " # Lower the exported program to the Ethos-U backend\n", + " edge_program_manager = to_edge_transform_and_lower(\n", + " quantized_exported_program,\n", + " partitioner=[partitioner],\n", + " compile_config=EdgeCompileConfig(\n", + " _check_ir_validity=False,\n", + " ),\n", + " )\n", + "\n", + " # Convert edge program to executorch\n", + " executorch_program_manager = edge_program_manager.to_executorch(\n", + " config=ExecutorchBackendConfig(extract_delegate_segments=False)\n", + " )\n", + "\n", + " # Save pte file\n", + " save_pte_program(executorch_program_manager, f\"{name}.pte\")\n", + "\n", + "# Create a new exported program using the quantized_graph_module\n", + "quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)\n", + "generate_pte(quantized_exported_program,compile_spec,\"original_model\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b6cae04", + "metadata": {}, + "source": [ + "Note that as part of the compilation process in `to_edge_transform_and_lower`, we get Weight Compression information:\n", + "```\n", + "Original Weights Size 522.50 KiB\n", + "NPU Encoded Weights Size 507.44 KiB\n", + "```\n", + "In other words, the original Weights are 522KB and after compilation and encoding by the compiler, we get 507KB of weights that will be read by the NPU at runtime. Remember this is for the case when we've not applied pruning or clustering. This will generate original_model.pte file that we will deploy on device later on. \n", + "\n", + "Next, let's move on to prune the model and evaluate its accuracy. We have a lot of weights in the original network, so we will apply 95% pruning rate." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "493eed60", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Prune the model\")\n", + "model.prune(pruning_method=prune.L1Unstructured, amount=0.95)\n", + "print(\"Evaluate pruned model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "82460ba6", + "metadata": {}, + "source": [ + "We obtain 37% top1 accuracy for the pruned model. That can seem surprising at first sight, but remember that when we prune, we randomly set 95% of the weights to 0. It is normal to lose accuracy when applying pruning. We need to retrain the model in order to recover the accuracy we've lost from the pruning. We can do that easily by calling the train function one more time. Once we are done with the retraining, it is important to remove the parameters we've pruned." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c816ad25", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Train the pruned model to recover the lost information\")\n", + "train(model)\n", + "# Remove the pruned parameters when we've retrained the model and recovered the lost accuracy\n", + "for a,b in model.prunable_parameters():\n", + " prune.remove(a, b)\n", + "\n", + "print(\"Evaluate pruned model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "fbb70d47", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy for the pruned workload so we have recovered the accuracy we've lost with the pruning. Let's quantize the pruned model, evaluate the accuracy of the int8 network and obtain a pte file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cdb0f59", + "metadata": {}, + "outputs": [], + "source": [ + "pruned_exported_program = torch.export.export(model, example_inputs)\n", + "pruned_graph_module = pruned_exported_program.module(check_guards=False)\n", + "quantized_pruned_graph_module = prepare_pt2e(pruned_graph_module, quantizer)\n", + "for batch_images,label in calibration_loader:\n", + " quantized_pruned_graph_module(*batch_images) # Calibrate the graph module with the example input\n", + "quantized_pruned_graph_module = convert_pt2e(quantized_pruned_graph_module)\n", + "print(\"Accuracy of the pruned quantized model\")\n", + "evaluate(quantized_pruned_graph_module)\n", + "\n", + "quantized_ep_pruned = torch.export.export(quantized_pruned_graph_module, example_inputs)\n", + "generate_pte(quantized_ep_pruned,compile_spec,\"pruned_model\")" + ] + }, + { + "cell_type": "markdown", + "id": "4263714e", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy of the quantized pruned model. What is interesting is that this time, the NPU encoded weights size shrank considerably:\n", + "```\n", + "Original Weights Size 522.50 KiB\n", + "NPU Encoded Weights Size 46.12 KiB\n", + "```\n", + "In other words, we are now solving the MNIST classification problem with just 46KB of encoded weights. This is a significant reduction from the 507KB we had in the original model.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "562fdb16", + "metadata": {}, + "source": [ + "# NPU performance\n", + "In the sections above, we generated two pte files - one pte for the original model and another pte for the pruned model. These models perform very similarly in terms of accuracy. Let's run both of these models on the NPU and analyse the performance at runtime.\n", + "\n", + "# Performance of the original model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bdd91dc", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "# Ensure the arm-none-eabi-gcc toolchain and FVP:s are available on $PATH\n", + "source ethos-u-scratch/setup_path.sh\n", + "\n", + "# Build executorch libraries cross-compiled for arm baremetal to executorch/cmake-out-arm\n", + "cmake --preset arm-baremetal \\\n", + "-DCMAKE_BUILD_TYPE=Release \\\n", + "-B../../cmake-out-arm ../..\n", + "cmake --build ../../cmake-out-arm --target install -j$(nproc) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "756ab779", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "# Build example executor runner application to examples/arm/ethos_u_minimal_example\n", + "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", + " -DCMAKE_BUILD_TYPE=Release \\\n", + " -DET_PTE_FILE_PATH=original_model.pte \\\n", + " -DTARGET_CPU=cortex-m55 \\\n", + " -DETHOSU_TARGET_NPU_CONFIG=ethos-u85-128 \\\n", + " -DMEMORY_MODE=Shared_Sram \\\n", + " -DSYSTEM_CONFIG=Ethos_U85_SYS_DRAM_Mid \\\n", + " -Bethos_u_original_model \\\n", + " executor_runner\n", + "cmake --build ethos_u_original_model -j$(nproc) -- arm_executor_runner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a525a09", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "# Run the pruned model\n", + "../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_original_model/arm_executor_runner --target=ethos-u85-128" + ] + }, + { + "cell_type": "markdown", + "id": "23ebdc46", + "metadata": {}, + "source": [ + "We obtain a total of 99k NPU Active cycles. The MAC engines of the NPU are active during 8k cycles and the Weight Decoder is active during 74k NPU cycles. It's worth noting that the data flow in the Ethos-U is pipelined. In other words, the MAC array and the Weight Decoder are working at the same time. Having a total of 99k NPU cycles and only 8k Active MAC cycles and 74k of Weight Decoder active cycles means that the NPU is spending most of the time decoding weights and the MAC array is underutilized. Pruning is designed to alleviate that bottleneck. Let's analyse the performance of the pruned workload.\n", + "\n", + "# Performance of the pruned model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7c09926", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "\n", + "# Build example executor runner application to examples/arm/ethos_u_minimal_example\n", + "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", + " -DCMAKE_BUILD_TYPE=Release \\\n", + " -DET_PTE_FILE_PATH=pruned_model.pte \\\n", + " -DTARGET_CPU=cortex-m55 \\\n", + " -DETHOSU_TARGET_NPU_CONFIG=ethos-u85-128 \\\n", + " -DMEMORY_MODE=Shared_Sram \\\n", + " -DSYSTEM_CONFIG=Ethos_U85_SYS_DRAM_Mid \\\n", + " -Bethos_u_pruned_model \\\n", + " executor_runner\n", + "cmake --build ethos_u_pruned_model -j$(nproc) -- arm_executor_runner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "891947f7", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "# Run the pruned model\n", + "../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_pruned_model/arm_executor_runner --target=ethos-u85-128" + ] + }, + { + "cell_type": "markdown", + "id": "e55ae929", + "metadata": {}, + "source": [ + "On the pruned model, the inference completes in 22k NPU cycles. The NPU still performs 8k MACs, but this time the number of cycles when the weight decoder is active has dropped to to 17k cycles. \n", + "It's also worth noting that the size of the pte file has been reduced significantly - from 518 KB of the original model to 57KB of the pruned workload. " + ] + }, + { + "cell_type": "markdown", + "id": "d934fe41", + "metadata": {}, + "source": [ + "# Conclusion\n", + "We defined a simple model to solve the MNIST dataset. The model is using Linear layers and is heavily memory-bound on the external memory. We pruned the model and obtain similar int8 accuracy between the original workload and the pruned counterpart. Let us put the results from the runtime in a table and draw a few conclusions: \n", + "\n", + "| Model |NPU_ACTIVE cycles | NPU Encoded Weight Size | Weight Decoder Active Cycles | External memory beats read | Size of the pte file |\n", + "| ----------------------------------------|----------------- | ------------------------- | -----------------------------|---------------------------------|-----------------------|\n", + "| Original model | 97k | 506 KB | 74k | 32k | 517 KB |\n", + "| Pruned model | 22k | 46 KB | 8k | 3k | 57 KB |\n", + "\n", + "For the pruned network, we obtain a significant uplift - over 3x improvement in the inference speed and a drastic reduction in the number of cycles when the Weight Decoder is active. The NPU will consume lower power and the size of the pruned model that we save on-device is significantly smaller compared to the original network." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv_py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/arm/run.sh b/examples/arm/run.sh index 191be3fc3fe..10f2a259800 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -61,7 +61,7 @@ function help() { echo " --output= Target build output folder Default: ${output_folder}" echo " --bundleio Create Bundled pte using Devtools BundelIO with Input/RefOutput included" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" - echo " --build_type= Build with Release, Debug or RelWithDebInfo, default is ${build_type}" + echo " --build_type= Build with Release, Debug, RelWithDebInfo or UndefinedSanitizer, default is ${build_type}" echo " --extra_build_flags= Extra flags to pass to cmake like -DET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=60000 Default: none " echo " --build_only Only build, don't run" echo " --toolchain= Ethos-U: Toolchain can be specified (e.g. bare metal as arm-none-eabi-gcc or zephyr as arm-zephyr-eabi-gcc Default: ${toolchain}" @@ -321,7 +321,8 @@ for i in "${!test_model[@]}"; do set -x backends/arm/scripts/build_executor_runner_vkml.sh --build_type=${build_type} \ --extra_build_flags="${extra_build_flags}" \ - --output="${output_folder}" + --output="${output_folder}" \ + ${bundleio_flag} if [ "$build_only" = false ] ; then backends/arm/scripts/run_vkml.sh --model=${pte_file} --build_path=${output_folder} fi diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index a11b4a2eebd..ed7d90c8b42 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -26,6 +26,7 @@ enable_model_converter=0 # model-converter tool for VGF output enable_vgf_lib=0 # vgf reader - runtime backend dependency enable_emulation_layer=0 # Vulkan layer driver - emulates Vulkan ML extensions enable_vulkan_sdk=0 # Download and export Vulkan SDK required by emulation layer +enable_mlsdk_pip_install=0 # This is a temporary option that will soon be the default # Figure out if setup.sh was called or sourced and save it into "is_script_sourced" (return 0 2>/dev/null) && is_script_sourced=1 || is_script_sourced=0 @@ -51,6 +52,7 @@ OPTION_LIST=( "--enable-emulation-layer Enable MLSDK Vulkan emulation layer" "--disable-ethos-u-deps Do not setup what is needed for Ethos-U" "--enable-mlsdk-deps Setup what is needed for MLSDK" + "--install-mlsdk-deps-with-pip Use MLSDK PyPi package instead of building from source" "--mlsdk-manifest-url URL to the MLSDK manifest for vulkan." "--help Display help" ) @@ -140,6 +142,10 @@ function check_options() { enable_vela=0 shift ;; + --install-mlsdk-deps-with-pip) + enable_mlsdk_pip_install=1 + shift + ;; --enable-mlsdk-deps) enable_model_converter=1 enable_vgf_lib=1 @@ -176,12 +182,22 @@ function setup_ethos_u_tools() { CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 pip install --no-dependencies -r $et_dir/backends/arm/requirements-arm-ethos-u.txt } +function setup_mlsdk_dependencies() { + log_step "mlsdk" "Installing MLSDK dependencies from pip" + pip install -r $et_dir/backends/arm/requirements-arm-vgf.txt +} + function create_setup_path(){ cd "${root_dir}" clear_setup_path log_step "path" "Generating setup path scripts at ${setup_path_script}" + local use_mlsdk_pip=0 + if use_mlsdk_pip_package; then + use_mlsdk_pip=1 + fi + if [[ "${enable_fvps}" -eq 1 ]]; then setup_path_fvp fi @@ -194,19 +210,48 @@ function create_setup_path(){ setup_path_vulkan fi - if [[ "${enable_model_converter}" -eq 1 ]]; then + if [[ "${enable_model_converter}" -eq 1 && "${use_mlsdk_pip}" -eq 0 ]]; then setup_path_model_converter fi - if [[ "${enable_vgf_lib}" -eq 1 ]]; then + if [[ "${enable_vgf_lib}" -eq 1 && "${use_mlsdk_pip}" -eq 0 ]]; then setup_path_vgf_lib fi if [[ "${enable_emulation_layer}" -eq 1 ]]; then - setup_path_emulation_layer + if [[ "${use_mlsdk_pip}" -eq 0 ]]; then + setup_path_emulation_layer + else + setup_path_emulation_layer_from_pip + fi + fi + + log_step "path" "Update PATH by sourcing ${setup_path_script}.{sh|fish}" +} + +function use_mlsdk_pip_package() { + os=$(uname -s) + arch=$(uname -m) + + if [[ "${enable_mlsdk_pip_install}" -eq 0 ]]; then + return 1 + fi + + if [[ "$os" == "Darwin" ]]; then + if [[ "${enable_mlsdk_pip_install}" -eq 1 ]]; then + log_step "mlsdk" "[error] MLSDK pip install not yet supported on MacOS" + exit 1 + fi + fi + + if [[ "$arch" == "arm64" || "$arch" == "aarch64" ]]; then + if [[ "${enable_mlsdk_pip_install}" -eq 1 ]]; then + log_step "mlsdk" "[error] MLSDK pip install not yet supported on aarch64" + exit 1 + fi fi - log_step "path" "Update PATH by sourcing ${setup_path_script}.{sh|fish}" + return 0 } @@ -224,6 +269,7 @@ if [[ $is_script_sourced -eq 0 ]]; then source $et_dir/backends/arm/scripts/fvp_utils.sh source $et_dir/backends/arm/scripts/toolchain_utils.sh source $et_dir/backends/arm/scripts/vulkan_utils.sh + source $et_dir/backends/arm/scripts/mlsdk_utils.sh log_step "main" "Checking platform and OS" check_platform_support @@ -239,8 +285,12 @@ if [[ $is_script_sourced -eq 0 ]]; then mlsdk_manifest_dir="${root_dir}/${mlsdk_manifest_dir}" fi - log_step "options" "root=${root_dir}, target-toolchain=${target_toolchain:-}, mlsdk-dir=${mlsdk_manifest_dir}" - log_step "options" "ethos-u: fvps=${enable_fvps}, toolchain=${enable_baremetal_toolchain}, vela=${enable_vela} | mlsdk: model-converter=${enable_model_converter}, vgf-lib=${enable_vgf_lib}, emu-layer=${enable_emulation_layer}, vulkan-sdk=${enable_vulkan_sdk}" + log_step "options" \ + "root=${root_dir}, target-toolchain=${target_toolchain:-}, mlsdk-dir=${mlsdk_manifest_dir}" + log_step "options" \ + "ethos-u: fvps=${enable_fvps}, toolchain=${enable_baremetal_toolchain}, vela=${enable_vela} | " \ + "mlsdk: model-converter=${enable_model_converter}, vgf-lib=${enable_vgf_lib}, " \ + "emu-layer=${enable_emulation_layer}, vulkan-sdk=${enable_vulkan_sdk}" # Setup toolchain if [[ "${enable_baremetal_toolchain}" -eq 1 ]]; then @@ -267,13 +317,18 @@ if [[ $is_script_sourced -eq 0 ]]; then if [[ "${enable_model_converter}" -eq 1 || \ "${enable_vgf_lib}" -eq 1 || \ "${enable_emulation_layer}" -eq 1 ]]; then - log_step "mlsdk" "Configuring MLSDK components (model-converter=${enable_model_converter}, vgf-lib=${enable_vgf_lib}, emu-layer=${enable_emulation_layer})" - source $et_dir/backends/arm/scripts/mlsdk_utils.sh - setup_mlsdk "${root_dir}" \ - "${mlsdk_manifest_dir}" \ - "${enable_model_converter}" \ - "${enable_vgf_lib}" \ - "${enable_emulation_layer}" + log_step "mlsdk" "Configuring MLSDK components (model-converter=${enable_model_converter}, " \ + "vgf-lib=${enable_vgf_lib}, emu-layer=${enable_emulation_layer})" + if use_mlsdk_pip_package; then + setup_mlsdk_dependencies + else + log_step "mlsdk" "Installing MLSDK dependencies from source" + setup_mlsdk ${root_dir} \ + ${mlsdk_manifest_dir} \ + ${enable_model_converter} \ + ${enable_vgf_lib} \ + ${enable_emulation_layer} + fi fi # Create the setup_path.sh used to create the PATH variable for shell diff --git a/examples/arm/ubsan/CMakeLists.txt b/examples/arm/ubsan/CMakeLists.txt new file mode 100644 index 00000000000..8d5d23211b1 --- /dev/null +++ b/examples/arm/ubsan/CMakeLists.txt @@ -0,0 +1,18 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +add_library(executorch_ubsan STATIC ubsan_runtime.c) + +target_compile_features(executorch_ubsan PRIVATE c_std_11) + +target_compile_options(executorch_ubsan PRIVATE -fno-sanitize=undefined) + +set_target_properties(executorch_ubsan PROPERTIES OUTPUT_NAME "ubsan") + +install( + TARGETS executorch_ubsan + EXPORT ExecuTorchTargets + ARCHIVE DESTINATION lib +) diff --git a/examples/arm/ubsan/ubsan_runtime.c b/examples/arm/ubsan/ubsan_runtime.c new file mode 100644 index 00000000000..62f411073ba --- /dev/null +++ b/examples/arm/ubsan/ubsan_runtime.c @@ -0,0 +1,488 @@ +/* Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#ifndef UBSAN_RUNTIME_PREFIX +#define UBSAN_RUNTIME_PREFIX "[UBSAN] " +#endif + +typedef struct { + const char* filename; + uint32_t line; + uint32_t column; +} __ubsan_source_location; + +typedef struct { + uint16_t type_kind; + uint16_t type_info; + char type_name[]; +} __ubsan_type_descriptor; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; +} __ubsan_overflow_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* lhs_type; + const __ubsan_type_descriptor* rhs_type; +} __ubsan_shift_out_of_bounds_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* array_type; + const __ubsan_type_descriptor* index_type; +} __ubsan_out_of_bounds_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; + uint8_t log_alignment; + uint8_t type_check_kind; +} __ubsan_type_mismatch_data_v1; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; +} __ubsan_vla_bound_data; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location attr_location; +} __ubsan_nonnull_return_data_v1; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location attr_location; + uint8_t arg_index; +} __ubsan_nullability_arg_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* from_type; + const __ubsan_type_descriptor* to_type; +} __ubsan_float_cast_overflow_data; + +typedef struct { + __ubsan_source_location location; + const __ubsan_type_descriptor* type; +} __ubsan_invalid_value_data; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location attr_location; + uint32_t arg_index; +} __ubsan_nonnull_arg_data; + +typedef struct { + __ubsan_source_location location; +} __ubsan_pointer_overflow_data; + +typedef struct { + __ubsan_source_location location; + __ubsan_source_location assumption_location; + uint64_t alignment; + uint8_t type_check_kind; +} __ubsan_alignment_assumption_data; + +static const char* ubsan_get_type_name(const __ubsan_type_descriptor* type) { + if (!type) { + return ""; + } + return type->type_name; +} + +static const char* ubsan_type_check_kind_string(uint8_t kind) { + switch (kind) { + case 0: + return "load of"; + case 1: + return "store to"; + case 2: + return "reference binding to"; + case 3: + return "member access within"; + case 4: + return "member call on"; + case 5: + return "constructor call for"; + case 6: + return "downcast of"; + case 7: + return "downcast of"; + case 8: + return "upcast of"; + case 9: + return "cast to virtual base of"; + default: + return "use of"; + } +} + +static uintptr_t ubsan_ptr_value(const void* ptr) { + return (uintptr_t)ptr; +} + +static void ubsan_abort(void) { +#if defined(__GNUC__) + __builtin_trap(); +#else + abort(); +#endif + while (1) { + } +} + +static void ubsan_print_location(const __ubsan_source_location* loc) { + if (!loc || !loc->filename) { + printf(UBSAN_RUNTIME_PREFIX "unknown location: "); + return; + } + printf(UBSAN_RUNTIME_PREFIX "%s:%u:%u: ", loc->filename, loc->line, + loc->column); +} + +static void ubsan_report_with_message(const __ubsan_source_location* loc, + const char* message) { + ubsan_print_location(loc); + printf("%s\n", message); + fflush(stdout); + ubsan_abort(); +} + +static void ubsan_report_overflow(const __ubsan_overflow_data* data, + const char* op, + uintptr_t lhs, + uintptr_t rhs) { + const char* type_name = ubsan_get_type_name(data->type); + char message[256]; + snprintf( + message, + sizeof(message), + "%s on type '%s' (lhs=0x%08" PRIxPTR ", rhs=0x%08" PRIxPTR ")", + op, + type_name, + lhs, + rhs); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_add_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "addition overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_sub_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "subtraction overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_mul_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "multiplication overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_negate_overflow(__ubsan_overflow_data* data, void* value) { + ubsan_report_overflow( + data, + "negation overflow", + ubsan_ptr_value(value), + 0); +} + +void __ubsan_handle_divrem_overflow(__ubsan_overflow_data* data, void* lhs, + void* rhs) { + ubsan_report_overflow( + data, + "division remainder overflow", + ubsan_ptr_value(lhs), + ubsan_ptr_value(rhs)); +} + +void __ubsan_handle_shift_out_of_bounds(__ubsan_shift_out_of_bounds_data* data, + void* lhs, void* rhs) { + const char* lhs_type = ubsan_get_type_name(data->lhs_type); + const char* rhs_type = ubsan_get_type_name(data->rhs_type); + uintptr_t lhs_val = ubsan_ptr_value(lhs); + uintptr_t rhs_val = ubsan_ptr_value(rhs); + char message[256]; + snprintf( + message, + sizeof(message), + "shift out of bounds (lhs=0x%08" PRIxPTR " of type '%s', rhs=0x%08" PRIxPTR + " of type '%s')", + lhs_val, + lhs_type, + rhs_val, + rhs_type); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_out_of_bounds(__ubsan_out_of_bounds_data* data, + void* index) { + uintptr_t idx_val = ubsan_ptr_value(index); + const char* idx_type = ubsan_get_type_name(data->index_type); + const char* array_type = ubsan_get_type_name(data->array_type); + char message[256]; + snprintf( + message, + sizeof(message), + "index out of bounds (index=0x%08" PRIxPTR " of type '%s' on array '%s')", + idx_val, + idx_type, + array_type); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_type_mismatch_v1(__ubsan_type_mismatch_data_v1* data, + void* ptr) { + uintptr_t address = (uintptr_t)ptr; + size_t alignment = + (data->log_alignment < (sizeof(size_t) * 8)) + ? ((size_t)1 << data->log_alignment) + : 0; + const char* type_name = ubsan_get_type_name(data->type); + const char* check_desc = ubsan_type_check_kind_string(data->type_check_kind); + + char message[256]; + if (address == 0) { + snprintf( + message, + sizeof(message), + "%s null pointer of type '%s'", + check_desc, + type_name); + } else if (alignment && (address & (alignment - 1))) { + snprintf( + message, + sizeof(message), + "%s misaligned address 0x%08" PRIxPTR " for type '%s' (alignment %zu)", + check_desc, + address, + type_name, + alignment); + } else { + snprintf( + message, + sizeof(message), + "%s address 0x%08" PRIxPTR " with insufficient alignment for type '%s'", + check_desc, + address, + type_name); + } + + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_vla_bound_not_positive(__ubsan_vla_bound_data* data, + void* bound) { + uintptr_t bound_val = ubsan_ptr_value(bound); + char message[256]; + snprintf( + message, + sizeof(message), + "variable length array bound (%" PRIuPTR ") is not positive", + (uintptr_t)bound_val); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_load_invalid_value(__ubsan_invalid_value_data* data, + void* pointer) { + uintptr_t addr = ubsan_ptr_value(pointer); + const char* type_name = ubsan_get_type_name(data->type); + char message[256]; + snprintf( + message, + sizeof(message), + "load of invalid value at 0x%08" PRIxPTR " for type '%s'", + addr, + type_name); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nonnull_return_v1(__ubsan_nonnull_return_data_v1* data, + __ubsan_source_location* where) { + (void)where; // Some toolchains leave this null; attr_location is reliable. + char message[256]; + if (data->attr_location.filename) { + snprintf( + message, + sizeof(message), + "null pointer returned from function marked 'returns_nonnull' " + "(attribute at %s:%u:%u)", + data->attr_location.filename, + data->attr_location.line, + data->attr_location.column); + } else { + snprintf( + message, + sizeof(message), + "null pointer returned from function marked 'returns_nonnull'"); + } + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nullability_return_v1( + __ubsan_nonnull_return_data_v1* data, __ubsan_source_location* where) { + (void)where; // Some toolchains leave this null; attr_location is reliable. + char message[256]; + snprintf( + message, + sizeof(message), + "null returned from non-null return (attribute at %s:%u:%u)", + data->attr_location.filename ? data->attr_location.filename : "", + data->attr_location.line, + data->attr_location.column); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nullability_arg_v1(__ubsan_nullability_arg_data* data, + __ubsan_source_location* where) { + (void)where; // Some toolchains leave this null; attr_location is reliable. + char message[256]; + snprintf( + message, + sizeof(message), + "null passed to non-null argument #%u (attribute at %s:%u:%u)", + data->arg_index, + data->attr_location.filename ? data->attr_location.filename : "", + data->attr_location.line, + data->attr_location.column); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_nonnull_arg(__ubsan_nonnull_arg_data* data) { + char message[256]; + snprintf( + message, + sizeof(message), + "null pointer passed to argument marked 'nonnull' (argument #%u, attribute at %s:%u:%u)", + data->arg_index, + data->attr_location.filename ? data->attr_location.filename : "", + data->attr_location.line, + data->attr_location.column); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_float_cast_overflow( + __ubsan_float_cast_overflow_data* data, void* from) { + uintptr_t raw = ubsan_ptr_value(from); + const char* from_type = ubsan_get_type_name(data->from_type); + const char* to_type = ubsan_get_type_name(data->to_type); + char message[256]; + snprintf( + message, + sizeof(message), + "floating point cast overflow (value bits=0x%08" PRIxPTR + ", from '%s' to '%s')", + raw, + from_type, + to_type); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_pointer_overflow(__ubsan_pointer_overflow_data* data, + void* base, void* result) { + uintptr_t base_val = ubsan_ptr_value(base); + uintptr_t result_val = ubsan_ptr_value(result); + char message[256]; + snprintf( + message, + sizeof(message), + "pointer overflow (base=0x%08" PRIxPTR ", result=0x%08" PRIxPTR ")", + base_val, + result_val); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_alignment_assumption( + __ubsan_alignment_assumption_data* data, void* pointer, + void* alignment, void* offset) { + uintptr_t ptr_val = ubsan_ptr_value(pointer); + uintptr_t align_val = ubsan_ptr_value(alignment); + uintptr_t offset_val = ubsan_ptr_value(offset); + char message[256]; + snprintf( + message, + sizeof(message), + "alignment assumption violated (ptr=0x%08" PRIxPTR ", alignment=%" PRIuPTR + ", offset=%" PRIuPTR ", required alignment=%" PRIu64 ")", + ptr_val, + align_val, + offset_val, + (unsigned long long)data->alignment); + ubsan_report_with_message(&data->location, message); +} + +void __ubsan_handle_builtin_unreachable(__ubsan_source_location* location) { + ubsan_report_with_message(location, "execution reached an unreachable point"); +} + +void __ubsan_handle_missing_return(__ubsan_source_location* location) { + ubsan_report_with_message(location, + "control reached end of void function without " + "returning"); +} + +void __ubsan_handle_invalid_builtin(__ubsan_source_location* location) { + ubsan_report_with_message(location, "invalid builtin usage"); +} + +void __ubsan_handle_cfi_check_fail(__ubsan_source_location* location, + void* data, void* vtable) { + uintptr_t type_hash = ubsan_ptr_value(data); + uintptr_t vtable_ptr = ubsan_ptr_value(vtable); + char message[256]; + snprintf( + message, + sizeof(message), + "control-flow integrity check failed (type hash=0x%08" PRIxPTR + ", vtable=0x%08" PRIxPTR ")", + type_hash, + vtable_ptr); + ubsan_report_with_message(location, message); +} + +void __ubsan_handle_cfi_check_fail_abort(__ubsan_source_location* location, + void* data, void* vtable) { + __ubsan_handle_cfi_check_fail(location, data, vtable); +} + +void __ubsan_handle_dynamic_type_cache_miss(void* data, void* ptr) { + uintptr_t type_hash = ubsan_ptr_value(data); + uintptr_t object_ptr = ubsan_ptr_value(ptr); + printf( + UBSAN_RUNTIME_PREFIX + "dynamic type cache miss (type hash=0x%08" PRIxPTR ", object=0x%08" PRIxPTR + ")\n", + type_hash, + object_ptr); + fflush(stdout); + ubsan_abort(); +} + +void __ubsan_on_error(void) { + printf(UBSAN_RUNTIME_PREFIX "runtime error detected\n"); + fflush(stdout); + ubsan_abort(); +} diff --git a/examples/arm/vgf_minimal_example.ipynb b/examples/arm/vgf_minimal_example.ipynb index 1f8e0a61601..c14430e3a2f 100644 --- a/examples/arm/vgf_minimal_example.ipynb +++ b/examples/arm/vgf_minimal_example.ipynb @@ -116,6 +116,7 @@ " VgfQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", + "from executorch.backends.arm.vgf import VgfCompileSpec\n", "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "\n", "# Create a compilation spec describing the target for configuring the quantizer\n", diff --git a/examples/arm/visualize.py b/examples/arm/visualize.py index 9f94f871186..cf50765a125 100644 --- a/examples/arm/visualize.py +++ b/examples/arm/visualize.py @@ -95,6 +95,11 @@ def is_end_of_command(qread_offset: int, end_idx: int) -> bool: qread_offset = 4 * int(event["args"]["qread"]) + while (cmd_index + chain_len <= queue_df_len - 1) and queue_df.iloc[ + cmd_index + chain_len + ]["scheduled_id"] in sub_ops: + chain_len += 1 + end_idx = cmd_index + chain_len if is_end_of_command(qread_offset, end_idx): end_ts = int(event["ts"]) - 1 @@ -102,12 +107,8 @@ def is_end_of_command(qread_offset: int, end_idx: int) -> bool: end_ts - start_ts, ] start_ts = end_ts - cmd_index += chain_len + cmd_index = end_idx chain_len = 1 - while (cmd_index + chain_len <= queue_df_len - 1) and queue_df.iloc[ - cmd_index + chain_len - ]["scheduled_id"] in sub_ops: - chain_len += 1 Agg = Union[str, Callable[[pd.Series], Any]] diff --git a/examples/demo-apps/react-native/rnllama/package.json b/examples/demo-apps/react-native/rnllama/package.json index 08d5d65fd30..3286dcbe575 100644 --- a/examples/demo-apps/react-native/rnllama/package.json +++ b/examples/demo-apps/react-native/rnllama/package.json @@ -54,7 +54,8 @@ }, "private": true, "resolutions": { - "cookie": ">=0.7.0" + "cookie": ">=0.7.0", + "glob": "^10.5.0" }, "packageManager": "yarn@1.22.22+sha512.a6b2f7906b721bba3d67d4aff083df04dad64c399707841b7acf00f6b133b7ac24255f2652fa22ae3534329dc6180534e98d17432037ff6fd140556e2bb3137e" } diff --git a/examples/demo-apps/react-native/rnllama/yarn.lock b/examples/demo-apps/react-native/rnllama/yarn.lock index ef1359a2eca..c7d29446dc2 100644 --- a/examples/demo-apps/react-native/rnllama/yarn.lock +++ b/examples/demo-apps/react-native/rnllama/yarn.lock @@ -3661,11 +3661,6 @@ fs-minipass@^3.0.0: dependencies: minipass "^7.0.3" -fs.realpath@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" - integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw== - fsevents@^2.3.2: version "2.3.3" resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.3.tgz#cac6407785d03675a2a5e1a5305c697b347d90d6" @@ -3731,10 +3726,10 @@ glob-parent@^5.1.2: dependencies: is-glob "^4.0.1" -glob@^10.2.2, glob@^10.3.10, glob@^10.4.2: - version "10.4.5" - resolved "https://registry.yarnpkg.com/glob/-/glob-10.4.5.tgz#f4d9f0b90ffdbab09c9d77f5f29b4262517b0956" - integrity sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg== +glob@^10.2.2, glob@^10.3.10, glob@^10.4.2, glob@^10.5.0, glob@^7.1.1, glob@^7.1.3, glob@^7.1.4: + version "10.5.0" + resolved "https://registry.yarnpkg.com/glob/-/glob-10.5.0.tgz#8ec0355919cd3338c28428a23d4f24ecc5fe738c" + integrity sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg== dependencies: foreground-child "^3.1.0" jackspeak "^3.1.2" @@ -3743,18 +3738,6 @@ glob@^10.2.2, glob@^10.3.10, glob@^10.4.2: package-json-from-dist "^1.0.0" path-scurry "^1.11.1" -glob@^7.1.1, glob@^7.1.3, glob@^7.1.4: - version "7.2.3" - resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b" - integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q== - dependencies: - fs.realpath "^1.0.0" - inflight "^1.0.4" - inherits "2" - minimatch "^3.1.1" - once "^1.3.0" - path-is-absolute "^1.0.0" - globals@^11.1.0: version "11.12.0" resolved "https://registry.yarnpkg.com/globals/-/globals-11.12.0.tgz#ab8795338868a0babd8525758018c2a7eb95c42e" @@ -3975,15 +3958,7 @@ indent-string@^4.0.0: resolved "https://registry.yarnpkg.com/indent-string/-/indent-string-4.0.0.tgz#624f8f4497d619b2d9768531d58f4122854d7251" integrity sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg== -inflight@^1.0.4: - version "1.0.6" - resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" - integrity sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA== - dependencies: - once "^1.3.0" - wrappy "1" - -inherits@2, inherits@2.0.4, inherits@^2.0.3, inherits@~2.0.3: +inherits@2.0.4, inherits@^2.0.3, inherits@~2.0.3: version "2.0.4" resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== @@ -5244,7 +5219,7 @@ mimic-fn@^2.1.0: resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-2.1.0.tgz#7ed2c2ccccaf84d3ffcb7a69b57711fc2083401b" integrity sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg== -minimatch@^3.0.2, minimatch@^3.0.4, minimatch@^3.1.1: +minimatch@^3.0.2, minimatch@^3.0.4: version "3.1.2" resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== @@ -5489,7 +5464,7 @@ on-headers@~1.0.2: resolved "https://registry.yarnpkg.com/on-headers/-/on-headers-1.0.2.tgz#772b0ae6aaa525c399e489adfad90c403eb3c28f" integrity sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA== -once@^1.3.0, once@^1.3.1, once@^1.4.0: +once@^1.3.1, once@^1.4.0: version "1.4.0" resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1" integrity sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w== @@ -5656,11 +5631,6 @@ path-exists@^4.0.0: resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== -path-is-absolute@^1.0.0: - version "1.0.1" - resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" - integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg== - path-key@^2.0.0, path-key@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/path-key/-/path-key-2.0.1.tgz#411cadb574c5a140d3a4b1910d40d80cc9f40b40" diff --git a/examples/devtools/CMakeLists.txt b/examples/devtools/CMakeLists.txt index 355ff375361..f541f70f86d 100644 --- a/examples/devtools/CMakeLists.txt +++ b/examples/devtools/CMakeLists.txt @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -47,7 +48,9 @@ find_package( ) add_executable(example_runner example_runner/example_runner.cpp) -target_compile_options(executorch INTERFACE -DET_EVENT_TRACER_ENABLED) +target_compile_options( + executorch INTERFACE -DET_EVENT_TRACER_ENABLED -DET_BUNDLE_IO_ENABLED +) target_include_directories( etdump INTERFACE ${CMAKE_CURRENT_BINARY_DIR}/../../devtools/include diff --git a/examples/models/__init__.py b/examples/models/__init__.py index 45abfd8f89d..6a6c4ff1875 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -40,6 +40,7 @@ class Model(str, Enum): Phi4Mini = "phi_4_mini" SmolLM2 = "smollm2" DeiTTiny = "deit_tiny" + Sdpa = "sdpa" def __str__(self) -> str: return self.value @@ -89,6 +90,7 @@ def __str__(self) -> str: str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"), str(Model.SmolLM2): ("smollm2", "SmolLM2Model"), str(Model.DeiTTiny): ("deit_tiny", "DeiTTinyModel"), + str(Model.Sdpa): ("toy_model", "SdpaModule"), } __all__ = [ diff --git a/examples/models/deit_tiny/model.py b/examples/models/deit_tiny/model.py index e92167bfbb4..e1db416d636 100644 --- a/examples/models/deit_tiny/model.py +++ b/examples/models/deit_tiny/model.py @@ -6,7 +6,6 @@ import logging import torch -from torchvision import transforms try: import timm # type: ignore @@ -15,8 +14,6 @@ "timm package is required for builtin 'deit_tiny'. Install timm." ) from e -from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD - from ..model_base import EagerModelBase @@ -27,16 +24,13 @@ def __init__(self): # type: ignore[override] def get_eager_model(self) -> torch.nn.Module: # type: ignore[override] logging.info("Loading timm deit_tiny_patch16_224 model") - model = timm.models.deit.deit_tiny_patch16_224(pretrained=False) - model.eval() + model = timm.models.deit.deit_tiny_patch16_224(pretrained=True) logging.info("Loaded timm deit_tiny_patch16_224 model") return model def get_example_inputs(self): # type: ignore[override] - normalize = transforms.Normalize( - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD - ) - return (normalize(torch.rand((1, 3, 224, 224))),) + input_shape = (1, 3, 224, 224) + return (torch.randn(input_shape),) __all__ = ["DeiTTinyModel"] diff --git a/examples/models/gemma3/CMakeLists.txt b/examples/models/gemma3/CMakeLists.txt index 0be346d70f2..d228ca53c46 100644 --- a/examples/models/gemma3/CMakeLists.txt +++ b/examples/models/gemma3/CMakeLists.txt @@ -102,8 +102,8 @@ list( # Link CUDA backend if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit REQUIRED) - list(APPEND link_libraries aoti_cuda) - executorch_target_link_options_shared_lib(aoti_cuda) + list(APPEND link_libraries aoti_cuda_backend) + executorch_target_link_options_shared_lib(aoti_cuda_backend) endif() # Add tokenizers diff --git a/examples/models/gemma3/CMakePresets.json b/examples/models/gemma3/CMakePresets.json new file mode 100644 index 00000000000..dcfeceba1cd --- /dev/null +++ b/examples/models/gemma3/CMakePresets.json @@ -0,0 +1,76 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "gemma3-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/gemma3", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "gemma3-cpu", + "displayName": "Gemma3 runner (CPU)", + "inherits": ["gemma3-base"] + }, + { + "name": "gemma3-cuda", + "displayName": "Gemma3 runner (CUDA)", + "inherits": ["gemma3-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Linux" + } + } + ], + "buildPresets": [ + { + "name": "gemma3-cpu", + "displayName": "Build Gemma3 runner (CPU)", + "configurePreset": "gemma3-cpu", + "targets": ["gemma3_e2e_runner"] + }, + { + "name": "gemma3-cuda", + "displayName": "Build Gemma3 runner (CUDA)", + "configurePreset": "gemma3-cuda", + "targets": ["gemma3_e2e_runner"] + } + ], + "workflowPresets": [ + { + "name": "gemma3-cpu", + "displayName": "Configure and build Gemma3 runner (CPU)", + "steps": [ + { + "type": "configure", + "name": "gemma3-cpu" + }, + { + "type": "build", + "name": "gemma3-cpu" + } + ] + }, + { + "name": "gemma3-cuda", + "displayName": "Configure and build Gemma3 runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "gemma3-cuda" + }, + { + "type": "build", + "name": "gemma3-cuda" + } + ] + } + ] +} diff --git a/examples/models/gemma3/README.md b/examples/models/gemma3/README.md index e24ebdf1a09..9d36ae2b625 100644 --- a/examples/models/gemma3/README.md +++ b/examples/models/gemma3/README.md @@ -78,23 +78,11 @@ Ensure you have a CUDA-capable GPU and CUDA toolkit installed on your system. ### Building for CUDA ```bash -# Install ExecuTorch. -./install_executorch.sh - -# Build the multimodal runner with CUDA -cmake --preset llm \ - -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out -S. -cmake --build cmake-out -j$(nproc) --target install --config Release - -# Build the Gemma3 runner -cmake -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -Sexamples/models/gemma3 \ - -Bcmake-out/examples/models/gemma3/ -cmake --build cmake-out/examples/models/gemma3 --target gemma3_e2e_runner --config Release +# Build the Gemma3 runner with CUDA enabled +make gemma3-cuda + +# Build the Gemma3 runner with CPU enabled +make gemma3-cpu ``` ## Running the model diff --git a/examples/models/glm/__init__.py b/examples/models/glm/__init__.py new file mode 100644 index 00000000000..aef380e7f6b --- /dev/null +++ b/examples/models/glm/__init__.py @@ -0,0 +1,16 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.glm.convert_weights import convert_weights +from executorch.examples.models.llama.model import Llama2Model + + +class GLMModel(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "GLMModel", + "convert_weights", +] diff --git a/examples/models/glm/config/1_5b_config.json b/examples/models/glm/config/1_5b_config.json new file mode 100644 index 00000000000..23576622255 --- /dev/null +++ b/examples/models/glm/config/1_5b_config.json @@ -0,0 +1,17 @@ +{ + "dim": 2048, + "ffn_dim_multiplier": 1, + "hidden_dim": 6144, + "n_heads": 16, + "head_dim": 128, + "n_kv_heads": 4, + "n_layers": 28, + "norm_eps": 1e-05, + "rope_theta": 10000.0, + "use_scaled_rope": false, + "vocab_size": 59264, + "use_hf_rope": true, + "attention_qkv_bias": false, + "use_qk_norm": false, + "model_architecture" : "GlmForCausalLM" +} diff --git a/examples/models/glm/convert_weights.py b/examples/models/glm/convert_weights.py new file mode 100644 index 00000000000..0568c9dccec --- /dev/null +++ b/examples/models/glm/convert_weights.py @@ -0,0 +1,79 @@ +import argparse +import os +from typing import Dict + +import torch +from safetensors.torch import load_file +from torchtune.models.convert_weights import get_mapped_key + +# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings. +_GLM_FROM_META = { + "tok_embeddings.weight": "model.embed_tokens.weight", + "norm.weight": "model.norm.weight", + "output.weight": "lm_head.weight", + "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight", + "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight", + "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight", + "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight", + "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight", + "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight", + "layers.{}.feed_forward.gate_up_proj.weight": "model.layers.{}.mlp.gate_up_proj.weight", + "layers.{}.feed_forward.down_proj.weight": "model.layers.{}.mlp.down_proj.weight", +} + + +def glm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _GLM_FROM_META.items()} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + if "lm_head.weight" not in state_dict: + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def convert_weights(input_dir: str, output_file: str) -> None: + pt_path = os.path.join(input_dir, "model.safetensors") + print("Loading checkpoint from file...") + sd = load_file(pt_path) + + print("Converting checkpoint...") + sd = glm_tune_to_meta(sd) + + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser(description="Convert GLM weights to Meta format.") + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/granite/__init__.py b/examples/models/granite/__init__.py new file mode 100644 index 00000000000..723ae7f561d --- /dev/null +++ b/examples/models/granite/__init__.py @@ -0,0 +1,16 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.granite.convert_weights import convert_weights +from executorch.examples.models.llama.model import Llama2Model + + +class GraniteModel(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "GraniteModel", + "convert_weights", +] diff --git a/examples/models/granite/config/2b_config.json b/examples/models/granite/config/2b_config.json new file mode 100644 index 00000000000..a1523af6a0e --- /dev/null +++ b/examples/models/granite/config/2b_config.json @@ -0,0 +1,19 @@ +{ + "dim": 2048, + "attention_qkv_bias": false, + "attention_multiplier": 0.015625, + "bos_idx": 0, + "embedding_scale_factor": 12.0, + "eos_idx": 0, + "act_fn": "silu", + "hidden_dim": 8192, + "n_heads": 32, + "n_layers": 40, + "n_kv_heads": 8, + "norm_eps": 1e-05, + "rope_theta": 10000000.0, + "vocab_size": 49159, + "use_hf_rope": false, + "residual_multiplier": 0.22, + "logits_scaling": 8.0 +} diff --git a/examples/models/granite/convert_weights.py b/examples/models/granite/convert_weights.py new file mode 100644 index 00000000000..06dd931bd21 --- /dev/null +++ b/examples/models/granite/convert_weights.py @@ -0,0 +1,106 @@ +import argparse + +import json +import os +from typing import Dict + +import torch +from safetensors.torch import load_file + +from torchtune.models.convert_weights import get_mapped_key + + +# Weight mappings from Granite 3's checkpoint to ExecuTorch's transformer parameters. +_GRANITE_TO_EXECUTORCH = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", +} + + +def granite_to_executorch( + state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Convert the state dict so that it matches what ExecuTorch's transformer definition expects. + """ + converted_state_dict = {} + for key, value in state_dict.items(): + new_key = get_mapped_key(key, _GRANITE_TO_EXECUTORCH) + converted_state_dict[new_key] = value + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + return converted_state_dict + + +def load_checkpoint_from_safetensors(input_dir: str) -> Dict: + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + # Sharded checkpoint. + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + # Load all the shards into memory + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + # Merge tensors into consolidated state dict. + merged_state_dict = {} + for weight_name, shard in weight_map.items(): + tensor = shard_to_weights[shard][weight_name] + merged_state_dict[weight_name] = tensor + return merged_state_dict + else: + # Single checkpoint. + state_dict = load_file(os.path.join(input_dir, "model.safetensors")) + return state_dict + + +def load_checkpoint(input_dir: str) -> Dict: + pytorch_path = os.path.join(input_dir, "pytorch_model.bin") + if os.path.exists(pytorch_path): + print("Loading checkpoint from PyTorch .bin file") + return torch.load(pytorch_path, map_location="cpu", weights_only=True) + print("Loading checkpoint from safetensors directory") + return load_checkpoint_from_safetensors(input_dir) + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + sd = load_checkpoint(input_dir) + print("Converting checkpoint...") + sd = granite_to_executorch(sd) + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Granite weights to ExecuTorch transformer format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/CMakePresets.json b/examples/models/llama/CMakePresets.json new file mode 100644 index 00000000000..b7be1b7e174 --- /dev/null +++ b/examples/models/llama/CMakePresets.json @@ -0,0 +1,67 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "llama-release", + "displayName": "Llama runner in Release mode", + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/llama", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "llama-debug", + "displayName": "Llama runner in Debug mode", + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/llama", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out" + } + } + ], + "buildPresets": [ + { + "name": "llama-release", + "displayName": "Build Llama runner in Release mode", + "configurePreset": "llama-release", + "targets": ["llama_main"] + }, + { + "name": "llama-debug", + "displayName": "Build Llama runner in Debug mode", + "configurePreset": "llama-debug", + "targets": ["llama_main"] + } + ], + "workflowPresets": [ + { + "name": "llama-release", + "displayName": "Configure and build Llama runner in Release mode", + "steps": [ + { + "type": "configure", + "name": "llama-release" + }, + { + "type": "build", + "name": "llama-release" + } + ] + }, + { + "name": "llama-debug", + "displayName": "Configure and build Llama runner in Debug mode", + "steps": [ + { + "type": "configure", + "name": "llama-debug" + }, + { + "type": "build", + "name": "llama-debug" + } + ] + } + ] +} diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index 0a81abdeee6..0ca39a3c80d 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -235,21 +235,16 @@ If you're interested in deploying on non-CPU backends, [please refer the non-cpu ## Step 3: Run on your computer to validate 1. Build executorch with optimized CPU performance as follows. Build options available [here](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt#L59). - ``` - cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out - - cmake --build cmake-out -j16 --target install --config Release - ``` +``` +cmake --workflow llm-release +``` Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the section of Common Issues and Mitigations below for solutions. 2. Build llama runner. ``` -cmake -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out/examples/models/llama \ - examples/models/llama - -cmake --build cmake-out/examples/models/llama -j16 --config Release +pushd examples/models/llama +cmake --workflow --preset llama-release +popd ``` 3. Run model. Run options available [here](https://github.com/pytorch/executorch/blob/main/examples/models/llama/main.cpp#L18-L40). diff --git a/examples/models/llama/evaluate/eager_eval.py b/examples/models/llama/evaluate/eager_eval.py index da4742cfc96..9d5d7ad447b 100644 --- a/examples/models/llama/evaluate/eager_eval.py +++ b/examples/models/llama/evaluate/eager_eval.py @@ -69,8 +69,8 @@ def device(self): def tok_encode(self, string: str, **kwargs): # pyre-ignore return self._tokenizer.encode(string, bos=False, eos=False) - def tok_decode(self, tokens): - return self._tokenizer.decode(tokens) + def tok_decode(self, tokens, **kwargs): + return self._tokenizer.decode([tokens] if isinstance(tokens, int) else tokens) def _model_call(self, inps): if self._use_kv_cache: diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 0d6dc87de2f..180cda46207 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -895,7 +895,6 @@ def _to_edge_and_lower_llama_xnnpack( if gen_tag_fn is not None: from executorch.exir.passes.external_constants_pass import ( delegate_external_constants_pass_unlifted, - external_constants_pass, ) assert ( @@ -906,18 +905,14 @@ def _to_edge_and_lower_llama_xnnpack( gen_tag_fn=gen_tag_fn, ) - # Also add a pass for 'to_executorch' to tag weights that aren't delegated. - additional_passes.append( - partial(external_constants_pass, gen_tag_fn=gen_tag_fn) - ) - builder = builder.to_edge_transform_and_lower(partitioners) if verbose: print_delegation_info(builder.edge_manager.exported_program().graph_module) - # we need builder.export_program - - return builder.to_executorch(passes=additional_passes) + # Add gen_tag_fn to tag non-delegated weights as well. + return builder.to_executorch( + passes=additional_passes, external_constants_tag=gen_tag_fn + ) def _to_edge_and_lower_llama_openvino( diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp index 899ea37d5be..b655a619b26 100644 --- a/examples/models/llama/main.cpp +++ b/examples/models/llama/main.cpp @@ -7,11 +7,14 @@ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated */ +#include #include #include #include -#include +#ifdef ET_EVENT_TRACER_ENABLED +#include +#endif #if defined(ET_USE_THREADPOOL) #include @@ -64,6 +67,11 @@ DEFINE_int32( DEFINE_bool(warmup, false, "Whether to run a warmup run."); +DEFINE_string( + etdump_path, + "etdump.in", + "If an etdump path is provided, generate an ETDump file at the specified path for profiling purposes."); + // Helper function to parse comma-separated string lists std::vector parseStringList(const std::string& input) { std::vector result; @@ -117,9 +125,26 @@ int32_t main(int32_t argc, char** argv) { ->_unsafe_reset_threadpool(num_performant_cores); } #endif + +#ifdef ET_EVENT_TRACER_ENABLED + // Create ETDumpGen and get raw pointer reference for later access + auto etdump_gen_ptr = std::make_unique(); + executorch::etdump::ETDumpGen* etdump_gen = etdump_gen_ptr.get(); +#endif + // create llama runner std::unique_ptr<::executorch::extension::llm::TextLLMRunner> runner = - example::create_llama_runner(model_path, tokenizer_path, data_paths); + example::create_llama_runner( + model_path, + tokenizer_path, + data_paths, + temperature, +#ifdef ET_EVENT_TRACER_ENABLED + std::move(etdump_gen_ptr) +#else + nullptr +#endif + ); if (runner == nullptr) { ET_LOG(Error, "Failed to create llama runner"); @@ -157,5 +182,25 @@ int32_t main(int32_t argc, char** argv) { return 1; } +#ifdef ET_EVENT_TRACER_ENABLED + if (etdump_gen != nullptr) { + executorch::etdump::ETDumpResult result = etdump_gen->get_etdump_data(); + if (result.buf != nullptr && result.size > 0) { + FILE* f = fopen(FLAGS_etdump_path.c_str(), "w+"); + if (f == nullptr) { + ET_LOG( + Error, + "Failed to open etdump file at path: %s", + FLAGS_etdump_path.c_str()); + } else { + fwrite((uint8_t*)result.buf, 1, result.size, f); + fclose(f); + ET_LOG(Info, "ETDump file written to: %s", FLAGS_etdump_path.c_str()); + } + free(result.buf); + } + } +#endif + return 0; } diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 3f82286b8ed..a0e9eb70498 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -49,6 +49,9 @@ class ModelArgs: model_architecture: str = ( "LlamaForCausalLM" # This setting is currently only supported for the QNN backend ) + attention_multiplier: Optional[float] = ( + None # Scaling factor 1/sqrt(d_k) in attention formula + ) norm_eps: float = 1e-5 post_attention_norm: bool = False post_ffn_norm: bool = False @@ -75,6 +78,9 @@ class ModelArgs: # at runtime. Enable it only necessary (e.g., use perplexity tools that requires # logits for all input tokens.) generate_full_logits: bool = False + logits_scaling: Optional[float] = ( + None # Scaling factor applied to the logits of model, functioning similarly to a temperature parameter. + ) enable_dynamic_shape: bool = False # export model with dynamic shape support # A dictionary mapping from pruned token-id to original token-id input_prune_map: Optional[Dict[int, int]] = None @@ -85,6 +91,9 @@ class ModelArgs: apply_output: bool = True # Use output layer (unembedding) inside the transformer use_qk_norm: bool = False # apply normalization to q and k in the attention qk_norm_before_rope: bool = False # when to apply qk norm + residual_multiplier: Optional[float] = ( + None # Scaling factor applied to the residual hidden states + ) use_hf_rope: bool = False # Use HuggingFace's RoPE implementation no_rope_layer_interval: Optional[int] = ( None # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795). @@ -122,6 +131,9 @@ class ModelArgs: attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) # Hybrid models can have layer types different from attention layer_types: Optional[list] = None + model_architecture: Optional[str] = ( + None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now. + ) def __post_init__(self): if self.n_kv_heads is None: diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp index 19ed9f88339..d2db805405e 100644 --- a/examples/models/llama/runner/runner.cpp +++ b/examples/models/llama/runner/runner.cpp @@ -36,22 +36,32 @@ std::unique_ptr create_llama_runner( const std::string& model_path, const std::string& tokenizer_path, std::optional data_path, - float temperature) { + float temperature, + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) { if (data_path.has_value()) { std::vector data_files; data_files.push_back(data_path.value()); return create_llama_runner( - model_path, tokenizer_path, std::move(data_files), temperature); + model_path, + tokenizer_path, + std::move(data_files), + temperature, + std::move(event_tracer)); } return create_llama_runner( - model_path, tokenizer_path, std::vector(), temperature); + model_path, + tokenizer_path, + std::vector(), + temperature, + std::move(event_tracer)); } std::unique_ptr create_llama_runner( const std::string& model_path, const std::string& tokenizer_path, std::vector data_files, - float temperature) { + float temperature, + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) { ET_LOG( Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", @@ -70,7 +80,11 @@ std::unique_ptr create_llama_runner( return nullptr; } return llm::create_text_llm_runner( - model_path, std::move(tokenizer), data_files); + model_path, + std::move(tokenizer), + data_files, + temperature, + std::move(event_tracer)); } } // namespace example diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h index 728ae57efa8..10225fcb81d 100644 --- a/examples/models/llama/runner/runner.h +++ b/examples/models/llama/runner/runner.h @@ -28,13 +28,15 @@ std::unique_ptr create_llama_runner( const std::string& model_path, const std::string& tokenizer_path, std::optional data_path, - float temperature = -1.0f); + float temperature = -1.0f, + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr); std::unique_ptr create_llama_runner( const std::string& model_path, const std::string& tokenizer_path, std::vector data_files = {}, - float temperature = -1.0f); + float temperature = -1.0f, + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr); std::unique_ptr load_llama_tokenizer( const std::string& tokenizer_path, diff --git a/examples/models/llama/runner/targets.bzl b/examples/models/llama/runner/targets.bzl index fd298ee628e..9c0b7265159 100644 --- a/examples/models/llama/runner/targets.bzl +++ b/examples/models/llama/runner/targets.bzl @@ -28,6 +28,9 @@ def define_common_targets(): exported_headers = [ "runner.h", ], + deps = [ + "//executorch/devtools/etdump:etdump_flatcc", + ], preprocessor_flags = [ "-DUSE_ATEN_LIB", ] if aten else [], diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py index 9e49f9e4e15..a9412d513c7 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -159,13 +159,27 @@ def quantize( # noqa C901 from torchao.utils import unwrap_tensor_subclass def filter_fn(m, fqn): + # Check if it's a regular nn.Linear is_linear = isinstance(m, nn.Linear) + + # Check if it's a LoRALinear (which has a base weight parameter to quantize) + is_lora_linear = False + try: + from executorch.examples.models.llama.lora import LoRALinear + + is_lora_linear = isinstance(m, LoRALinear) + except ImportError: + pass + + # Check if the weight shape is compatible with group size has_shape_compatible_with_group_size = False - if is_linear: + if is_linear or is_lora_linear: has_shape_compatible_with_group_size = ( m.weight.shape[1] % group_size == 0 ) - return is_linear and has_shape_compatible_with_group_size + return ( + is_linear or is_lora_linear + ) and has_shape_compatible_with_group_size quantize_( model, diff --git a/examples/models/llama/targets.bzl b/examples/models/llama/targets.bzl index 66c5dacc8e9..42512145eed 100644 --- a/examples/models/llama/targets.bzl +++ b/examples/models/llama/targets.bzl @@ -19,6 +19,7 @@ def define_common_targets(): "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/threadpool:threadpool", "//executorch/extension/threadpool:cpuinfo_utils", + "//executorch/devtools/etdump:etdump_flatcc" + aten_suffix, ], external_deps = [ "gflags", diff --git a/examples/models/llava/CMakePresets.json b/examples/models/llava/CMakePresets.json new file mode 100644 index 00000000000..0ca4c543969 --- /dev/null +++ b/examples/models/llava/CMakePresets.json @@ -0,0 +1,38 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "llava", + "displayName": "Llava runner", + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/llava", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out" + } + } + ], + "buildPresets": [ + { + "name": "llava", + "displayName": "Build Llava runner", + "configurePreset": "llava", + "targets": ["llava_main"] + } + ], + "workflowPresets": [ + { + "name": "llava", + "displayName": "Configure and build Llava runner", + "steps": [ + { + "type": "configure", + "name": "llava" + }, + { + "type": "build", + "name": "llava" + } + ] + } + ] +} diff --git a/examples/models/moshi/mimi/install_requirements.sh b/examples/models/moshi/mimi/install_requirements.sh index bddd960f8a7..20273f5fdac 100755 --- a/examples/models/moshi/mimi/install_requirements.sh +++ b/examples/models/moshi/mimi/install_requirements.sh @@ -7,7 +7,7 @@ set -x -conda install -c conda-forge "ffmpeg<8" -y +sudo apt install ffmpeg -y pip install torchcodec==0.7.0.dev20251012 --extra-index-url https://download.pytorch.org/whl/nightly/cpu pip install moshi==0.2.11 pip install bitsandbytes soundfile einops diff --git a/examples/models/phi-3-mini/README.md b/examples/models/phi-3-mini/README.md index 86160e0b39a..dac378213d8 100644 --- a/examples/models/phi-3-mini/README.md +++ b/examples/models/phi-3-mini/README.md @@ -30,9 +30,7 @@ The model artifact `model.pte` size is about 2.0GB. 3. Build and run the model. - Build executorch with LLM preset: ``` -cmake --preset llm -DCMAKE_INSTALL_PREFIX=cmake-out - -cmake --build cmake-out -j16 --target install --config Release +cmake --workflow llm-release ``` - Build Phi-3-mini runner. ``` diff --git a/examples/models/toy_model/__init__.py b/examples/models/toy_model/__init__.py index 333a625af1b..87456e3fd4c 100644 --- a/examples/models/toy_model/__init__.py +++ b/examples/models/toy_model/__init__.py @@ -10,6 +10,7 @@ Conv1dModule, LinearModule, MulModule, + SdpaModule, SoftmaxModule, ) @@ -19,5 +20,6 @@ Conv1dModule, LinearModule, MulModule, + SdpaModule, SoftmaxModule, ] diff --git a/examples/models/toy_model/model.py b/examples/models/toy_model/model.py index e1dd290b829..a31149c29af 100644 --- a/examples/models/toy_model/model.py +++ b/examples/models/toy_model/model.py @@ -105,3 +105,39 @@ def get_eager_model(self) -> torch.nn.Module: def get_example_inputs(self): return (torch.randn(1, 3, 10),) + + +class SdpaModule(torch.nn.Module, EagerModelBase): + def __init__(self): + super().__init__() + + def forward(self, query, key, value): + out = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ) + return out + + def get_eager_model(self) -> torch.nn.Module: + return self + + def get_example_inputs(self): + # Input shape: (batch, num_heads, seq_len, head_dim) + batch_size = 2 + num_heads = 8 + seq_len = 128 + head_dim = 64 + query = torch.randn( + batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16 + ) + key = torch.randn( + batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16 + ) + value = torch.randn( + batch_size, num_heads, seq_len, head_dim, dtype=torch.bfloat16 + ) + return (query, key, value) diff --git a/examples/models/voxtral/CMakeLists.txt b/examples/models/voxtral/CMakeLists.txt index 866d17160ba..24a1096c889 100644 --- a/examples/models/voxtral/CMakeLists.txt +++ b/examples/models/voxtral/CMakeLists.txt @@ -39,18 +39,16 @@ executorch_target_link_options_shared_lib(executorch) set(link_libraries executorch gflags) set(_srcs multimodal.cpp) -list( - APPEND - link_libraries - optimized_native_cpu_ops_lib - quantized_ops_lib - custom_ops - cpublas - eigen_blas -) +# Common ops for all builds +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) -executorch_target_link_options_shared_lib(quantized_ops_lib) -executorch_target_link_options_shared_lib(custom_ops) + +# CPU-only builds need quantized and custom ops +if(NOT EXECUTORCH_BUILD_CUDA AND MSVC) + list(APPEND link_libraries quantized_ops_lib custom_ops) + executorch_target_link_options_shared_lib(quantized_ops_lib) + executorch_target_link_options_shared_lib(custom_ops) +endif() # XNNPACK if(TARGET xnnpack_backend) @@ -89,8 +87,11 @@ list( # Link CUDA backend if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit REQUIRED) - list(APPEND link_libraries aoti_cuda) - executorch_target_link_options_shared_lib(aoti_cuda) + list(APPEND link_libraries aoti_cuda_backend) + if(NOT MSVC) + # On non-MSVC, use shared lib options + executorch_target_link_options_shared_lib(aoti_cuda_backend) + endif() endif() if(EXECUTORCH_BUILD_METAL) @@ -104,7 +105,7 @@ list(APPEND link_libraries tokenizers::tokenizers) add_executable(voxtral_runner ${_srcs}) if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options_gc_sections(voxtral_runner) - if(NOT APPLE) + if(NOT APPLE AND NOT MSVC) target_link_options(voxtral_runner PRIVATE "LINKER:-s") endif() endif() @@ -112,3 +113,14 @@ endif() target_include_directories(voxtral_runner PUBLIC ${_common_include_directories}) target_link_libraries(voxtral_runner PUBLIC ${link_libraries}) target_compile_options(voxtral_runner PUBLIC ${_common_compile_options}) + +# On Windows, copy required DLLs to the executable directory +if(MSVC AND EXECUTORCH_BUILD_CUDA) + add_custom_command( + TARGET voxtral_runner + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ + $ + COMMENT "Copying aoti_cuda_shims.dll to voxtral_runner directory" + ) +endif() diff --git a/examples/models/voxtral/CMakePresets.json b/examples/models/voxtral/CMakePresets.json new file mode 100644 index 00000000000..b44eca42f74 --- /dev/null +++ b/examples/models/voxtral/CMakePresets.json @@ -0,0 +1,109 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "voxtral-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/voxtral", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "voxtral-cpu", + "displayName": "Voxtral runner (CPU)", + "inherits": ["voxtral-base"] + }, + { + "name": "voxtral-cuda", + "displayName": "Voxtral runner (CUDA)", + "inherits": ["voxtral-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Linux" + } + }, + { + "name": "voxtral-metal", + "displayName": "Voxtral runner (Metal)", + "inherits": ["voxtral-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + } + ], + "buildPresets": [ + { + "name": "voxtral-cpu", + "displayName": "Build Voxtral runner (CPU)", + "configurePreset": "voxtral-cpu", + "targets": ["voxtral_runner"] + }, + { + "name": "voxtral-cuda", + "displayName": "Build Voxtral runner (CUDA)", + "configurePreset": "voxtral-cuda", + "targets": ["voxtral_runner"] + }, + { + "name": "voxtral-metal", + "displayName": "Build Voxtral runner (Metal)", + "configurePreset": "voxtral-metal", + "targets": ["voxtral_runner"] + } + ], + "workflowPresets": [ + { + "name": "voxtral-cpu", + "displayName": "Configure and build Voxtral runner (CPU)", + "steps": [ + { + "type": "configure", + "name": "voxtral-cpu" + }, + { + "type": "build", + "name": "voxtral-cpu" + } + ] + }, + { + "name": "voxtral-cuda", + "displayName": "Configure and build Voxtral runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "voxtral-cuda" + }, + { + "type": "build", + "name": "voxtral-cuda" + } + ] + }, + { + "name": "voxtral-metal", + "displayName": "Configure and build Voxtral runner (Metal)", + "steps": [ + { + "type": "configure", + "name": "voxtral-metal" + }, + { + "type": "build", + "name": "voxtral-metal" + } + ] + } + ] +} diff --git a/examples/models/voxtral/README.md b/examples/models/voxtral/README.md index 7ab35819d80..30da684722e 100644 --- a/examples/models/voxtral/README.md +++ b/examples/models/voxtral/README.md @@ -122,51 +122,20 @@ python -m executorch.extension.audio.mel_spectrogram \ ### Building for CPU (XNNPack) ``` -# Build and install ExecuTorch -cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out -DEXECUTORCH_ENABLE_LOGGING=ON && cmake --build cmake-out -j16 --target install --config Release - # Build and install Voxtral runner -cmake -DCMAKE_INSTALL_PREFIX=cmake-out -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release -Bcmake-out/examples/models/voxtral examples/models/voxtral && cmake --build cmake-out/examples/models/voxtral -j16 --config Release +make voxtral-cpu ``` ### Building for CUDA ``` -# Install ExecuTorch with CUDA support -./install_executorch.sh - -# Build the multimodal runner with CUDA -cmake --preset llm \ - -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out -S. -cmake --build cmake-out -j16 --target install --config Release - -cmake -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -Sexamples/models/voxtral \ - -Bcmake-out/examples/models/voxtral/ -cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release +# Build Voxtral runner with CUDA +make voxtral-cuda ``` ### Building for Metal ``` -# Install ExecuTorch with Metal support -CMAKE_ARGS="-DEXECUTORCH_BUILD_METAL=ON" ./install_executorch.sh - -# Build the multimodal runner with Metal -cmake --preset llm \ - -DEXECUTORCH_BUILD_METAL=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out -S. -cmake --build cmake-out -j16 --target install --config Release - -cmake -DEXECUTORCH_BUILD_METAL=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -Sexamples/models/voxtral \ - -Bcmake-out/examples/models/voxtral/ -cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release +# Build Voxtral runner with Metal +make voxtral-metal ``` ## Running the model @@ -238,4 +207,4 @@ afconvert -f WAVE -d LEI16 call_samantha_hall.aiff call_samantha_hall.wav ## Android and iOS mobile demo apps -We have example mobile demo apps for Android and iOS (using XNNPACK) [here](https://github.com/meta-pytorch/executorch-examples/tree/main/llm) \ No newline at end of file +We have example mobile demo apps for Android and iOS (using XNNPACK) [here](https://github.com/meta-pytorch/executorch-examples/tree/main/llm) diff --git a/examples/models/whisper/CMakeLists.txt b/examples/models/whisper/CMakeLists.txt index 70f5892baa7..295779fd626 100644 --- a/examples/models/whisper/CMakeLists.txt +++ b/examples/models/whisper/CMakeLists.txt @@ -69,8 +69,8 @@ list( # Link CUDA backend if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit REQUIRED) - list(APPEND _link_libraries aoti_cuda) - executorch_target_link_options_shared_lib(aoti_cuda) + list(APPEND _link_libraries aoti_cuda_backend) + executorch_target_link_options_shared_lib(aoti_cuda_backend) endif() if(EXECUTORCH_BUILD_METAL) diff --git a/examples/models/whisper/CMakePresets.json b/examples/models/whisper/CMakePresets.json new file mode 100644 index 00000000000..a081ad5be29 --- /dev/null +++ b/examples/models/whisper/CMakePresets.json @@ -0,0 +1,109 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "whisper-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/whisper", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "whisper-cpu", + "displayName": "Whisper runner (CPU)", + "inherits": ["whisper-base"] + }, + { + "name": "whisper-cuda", + "displayName": "Whisper runner (CUDA)", + "inherits": ["whisper-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Linux" + } + }, + { + "name": "whisper-metal", + "displayName": "Whisper runner (Metal)", + "inherits": ["whisper-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + } + ], + "buildPresets": [ + { + "name": "whisper-cpu", + "displayName": "Build Whisper runner (CPU)", + "configurePreset": "whisper-cpu", + "targets": ["whisper_runner"] + }, + { + "name": "whisper-cuda", + "displayName": "Build Whisper runner (CUDA)", + "configurePreset": "whisper-cuda", + "targets": ["whisper_runner"] + }, + { + "name": "whisper-metal", + "displayName": "Build Whisper runner (Metal)", + "configurePreset": "whisper-metal", + "targets": ["whisper_runner"] + } + ], + "workflowPresets": [ + { + "name": "whisper-cpu", + "displayName": "Configure and build Whisper runner (CPU)", + "steps": [ + { + "type": "configure", + "name": "whisper-cpu" + }, + { + "type": "build", + "name": "whisper-cpu" + } + ] + }, + { + "name": "whisper-cuda", + "displayName": "Configure and build Whisper runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "whisper-cuda" + }, + { + "type": "build", + "name": "whisper-cuda" + } + ] + }, + { + "name": "whisper-metal", + "displayName": "Configure and build Whisper runner (Metal)", + "steps": [ + { + "type": "configure", + "name": "whisper-metal" + }, + { + "type": "build", + "name": "whisper-metal" + } + ] + } + ] +} diff --git a/examples/models/whisper/README.md b/examples/models/whisper/README.md index a4025441f7e..329ef55e8b6 100644 --- a/examples/models/whisper/README.md +++ b/examples/models/whisper/README.md @@ -20,23 +20,22 @@ module to generate the spectrogram tensor. ## Build -Currently we have CUDA build support only. CPU and Metal backend builds are WIP. +Currently we have CUDA and Metal build support. -```bash -# Install ExecuTorch libraries: -cmake --preset llm -DEXECUTORCH_BUILD_CUDA=ON -DCMAKE_INSTALL_PREFIX=cmake-out -DCMAKE_BUILD_TYPE=Release . -Bcmake-out -cmake --build cmake-out -j$(nproc) --target install --config Release - -# Build the runner: -cmake \ - -B cmake-out/examples/models/whisper \ - -S examples/models/whisper -cmake --build cmake-out/examples/models/whisper -j$(nproc) +For CPU: +``` +make whisper-cpu ``` -The first cmake command build produces a static library named `extension_asr_runner`. The second cmake command links it into your -application together with the standard ExecuTorch runtime libraries and the -tokenizer target (`tokenizers::tokenizers`). +For CUDA: +``` +make whisper-cuda +``` + +For Metal: +``` +make whisper-metal +``` ## Usage @@ -44,6 +43,8 @@ tokenizer target (`tokenizers::tokenizers`). Use [Optimum-ExecuTorch](https://github.com/huggingface/optimum-executorch) to export a Whisper model from Hugging Face: +#### CUDA backend: + ```bash optimum-cli export executorch \ --model openai/whisper-small \ @@ -58,9 +59,27 @@ This command generates: - `model.pte` — Compiled Whisper model - `aoti_cuda_blob.ptd` — Weight data file for CUDA backend +#### Metal backend: + +```bash +optimum-cli export executorch \ + --model openai/whisper-small \ + --task automatic-speech-recognition \ + --recipe metal \ + --dtype bfloat16 \ + --output_dir ./ +``` + +This command generates: +- `model.pte` — Compiled Whisper model +- `aoti_metal_blob.ptd` — Weight data file for Metal backend + +### Preprocessor + Export a preprocessor to convert raw audio to mel-spectrograms: ```bash +# Use --feature_size 128 for whisper-large-v3 and whisper-large-v3-turbo python -m executorch.extension.audio.mel_spectrogram \ --feature_size 80 \ --stack_output \ @@ -70,7 +89,7 @@ python -m executorch.extension.audio.mel_spectrogram \ ### Quantization -Export quantized models to reduce size and improve performance: +Export quantized models to reduce size and improve performance (Not enabled for Metal yet): ```bash # 4-bit tile packed quantization for encoder @@ -90,14 +109,22 @@ optimum-cli export executorch \ ### Download Tokenizer -Download the tokenizer files required for inference: +Download the tokenizer files required for inference according to your model version: +**For Whisper Small:** ```bash curl -L https://huggingface.co/openai/whisper-small/resolve/main/tokenizer.json -o tokenizer.json curl -L https://huggingface.co/openai/whisper-small/resolve/main/tokenizer_config.json -o tokenizer_config.json curl -L https://huggingface.co/openai/whisper-small/resolve/main/special_tokens_map.json -o special_tokens_map.json ``` +**For Whisper Large v2:** +```bash +curl -L https://huggingface.co/openai/whisper-large-v2/resolve/main/tokenizer.json -o tokenizer.json +curl -L https://huggingface.co/openai/whisper-large-v2/resolve/main/tokenizer_config.json -o tokenizer_config.json +curl -L https://huggingface.co/openai/whisper-large-v2/resolve/main/special_tokens_map.json -o special_tokens_map.json +``` + ### Prepare Audio Generate test audio or use an existing WAV file. The model expects 16kHz mono audio. @@ -111,6 +138,8 @@ python -c "from datasets import load_dataset; import soundfile as sf; sample = l After building the runner (see [Build](#build) section), execute it with the exported model and audio: +#### CUDA backend: + ```bash # Set library path for CUDA dependencies export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH @@ -124,3 +153,16 @@ cmake-out/examples/models/whisper/whisper_runner \ --processor_path whisper_preprocessor.pte \ --temperature 0 ``` + +#### Metal backend: + +```bash +# Run the Whisper runner +cmake-out/examples/models/whisper/whisper_runner \ + --model_path model.pte \ + --data_path aoti_metal_blob.ptd \ + --tokenizer_path ./ \ + --audio_path output.wav \ + --processor_path whisper_preprocessor.pte \ + --temperature 0 +``` diff --git a/examples/models/whisper/main.cpp b/examples/models/whisper/main.cpp index b4462e2c39a..080106c8915 100644 --- a/examples/models/whisper/main.cpp +++ b/examples/models/whisper/main.cpp @@ -109,7 +109,11 @@ int main(int argc, char** argv) { executorch::extension::asr::AsrTranscribeConfig config; config.max_new_tokens = FLAGS_max_new_tokens; config.temperature = static_cast(FLAGS_temperature); - config.decoder_start_token_id = 50257; + + // All Whisper models from HuggingFace now use the v3 tokenizer format + // where token 50257 = <|endoftext|> and token 50258 = <|startoftranscript|> + config.decoder_start_token_id = 50258; + ET_LOG(Info, "Using decoder_start_token_id=50258"); auto result = runner.transcribe(features, config, [&](const std::string& piece) { diff --git a/examples/nxp/aot_neutron_compile.py b/examples/nxp/aot_neutron_compile.py index 9e3f85c8176..3f6ce7b1910 100644 --- a/examples/nxp/aot_neutron_compile.py +++ b/examples/nxp/aot_neutron_compile.py @@ -9,7 +9,6 @@ import io import logging from collections import defaultdict -from typing import Iterator import executorch.extension.pybindings.portable_lib import executorch.kernels.quantized # noqa F401 @@ -25,6 +24,7 @@ from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import post_training_quantize from executorch.devtools.visualization.visualization_utils import ( visualize_with_clusters, ) @@ -37,7 +37,6 @@ ) from executorch.extension.export_util import save_pte_program from torch.export import export -from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model from .models.mobilenet_v2 import MobilenetV2 @@ -109,44 +108,6 @@ def get_model_and_inputs_from_name(model_name: str): } -def post_training_quantize( - model, - calibration_inputs: tuple[torch.Tensor] | Iterator[tuple[torch.Tensor]], - neutron_target_spec: NeutronTargetSpec, -): - """Quantize the provided model. - - :param model: Aten model to quantize. - :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model - input. Or an iterator over such tuples. - :param _neutron_target_spec: The functionality for probing the properties of Neutron Target. - """ - # Based on executorch.examples.arm.aot_amr_compiler.quantize - logging.info("Quantizing model") - logging.debug(f"---> Original model: {model}") - quantizer = NeutronQuantizer(neutron_target_spec) - - m = prepare_pt2e(model, quantizer) - # Calibration: - logging.debug("Calibrating model") - - def _get_batch_size(data): - return data[0].shape[0] - - if not isinstance( - calibration_inputs, tuple - ): # Assumption that calibration_inputs is finite. - for i, data in enumerate(calibration_inputs): - if i % (1000 // _get_batch_size(data)) == 0: - logging.debug(f"{i * _get_batch_size(data)} calibration inputs done") - m(*data) - else: - m(*calibration_inputs) - m = convert_pt2e(m) - logging.debug(f"---> Quantized model: {m}") - return m - - if __name__ == "__main__": # noqa C901 parser = argparse.ArgumentParser() parser.add_argument( @@ -254,7 +215,8 @@ def _get_batch_size(data): "No calibration inputs available, using the example inputs instead" ) calibration_inputs = example_inputs - module = post_training_quantize(module, calibration_inputs, neutron_target_spec) + quantizer = NeutronQuantizer(neutron_target_spec) + module = post_training_quantize(module, calibration_inputs, quantizer) if args.so_library is not None: logging.debug(f"Loading libraries: {args.so_library}") diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index 1157554c050..a3f4a2d385a 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -18,12 +18,16 @@ * all fp32 tensors. */ +#include +#include #include #include #include +#include #include +#include #include #include #include @@ -42,8 +46,10 @@ #include #include #include +#endif -#include +#ifdef ET_BUNDLE_IO_ENABLED +#include #endif static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB @@ -54,7 +60,7 @@ DEFINE_string( model_path, "model.pte", "Model serialized in flatbuffer format."); -DEFINE_string(data_path, "", "Path to data file."); +DEFINE_string(data_path, "", "Path to data file (.ptd)."); DEFINE_string(inputs, "", "Comma-separated list of input files"); DEFINE_string( output_file, @@ -74,10 +80,22 @@ DEFINE_int32( -1, "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); +#ifdef ET_BUNDLE_IO_ENABLED +DEFINE_double(bundleio_rtol, 0.01, "Relative tolerance for bundled IO."); +DEFINE_double(bundleio_atol, 0.01, "Absolute tolerance for bundled IO."); +#endif + using executorch::aten::ScalarType; using executorch::aten::Tensor; +#ifdef ET_BUNDLE_IO_ENABLED +using executorch::bundled_program::compute_method_output_error_stats; +using executorch::bundled_program::ErrorStats; +using executorch::bundled_program::verify_method_outputs; +#endif +using executorch::extension::BufferDataLoader; using executorch::extension::FileDataLoader; using executorch::extension::FlatTensorDataMap; +using executorch::runtime::DataLoader; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::EventTracer; @@ -142,6 +160,25 @@ class EventTraceManager { std::shared_ptr event_tracer_ptr_; }; +#ifdef ET_BUNDLE_IO_ENABLED +std::vector try_load_file(const std::filesystem::path& path) { + std::ifstream file(path, std::ios::binary | std::ios::ate); + ET_CHECK_MSG( + file.is_open(), "Could not open file '%s'", path.string().c_str()); + + const std::size_t nbytes = static_cast(file.tellg()); + file.seekg(0, std::ios::beg); + + std::vector file_data(nbytes); + ET_CHECK_MSG( + file.read(reinterpret_cast(file_data.data()), nbytes), + "Could not load contents of file '%s'", + path.string().c_str()); + + return file_data; +} +#endif + int main(int argc, char** argv) { executorch::runtime::runtime_init(); @@ -172,20 +209,86 @@ int main(int argc, char** argv) { opt_guard.emplace(); } #endif // ET_USE_THREADPOOL - // Create a loader to get the data of the program file. There are other - // DataLoaders that use mmap() or point to data that's already in memory, and - // users can create their own DataLoaders to load from arbitrary sources. - const char* model_path = FLAGS_model_path.c_str(); - Result loader = FileDataLoader::from(model_path); - ET_CHECK_MSG( - loader.ok(), - "FileDataLoader::from() failed: 0x%" PRIx32, - (uint32_t)loader.error()); - // Load .ptd file if provided + bool bundle_io = false; + size_t program_data_len = 0; + const void* program_data = nullptr; + +#ifdef ET_BUNDLE_IO_ENABLED + std::vector model_file_data = try_load_file(FLAGS_model_path); + uint8_t* model_pte = model_file_data.data(); + size_t pte_size = model_file_data.size(); + constexpr size_t testset_idx = 0; + + // Check for bundled IO provided model. + bundle_io = executorch::bundled_program::is_bundled_program( + reinterpret_cast(model_pte), pte_size); + + if (bundle_io) { + // BundleIO bpte file is provided - dig out the actual model from the data + // area. + ET_LOG(Debug, "PTE Model with bundle io detected."); + Error status = executorch::bundled_program::get_program_data( + reinterpret_cast(model_pte), + pte_size, + &program_data, + &program_data_len); + + ET_CHECK_MSG( + status == Error::Ok, + "get_program_data() from bundle PTE failed: 0x%x" PRIx32, + static_cast(status)); + } else { + ET_LOG(Debug, "PTE Model has no bundled IO"); + } +#endif + + // Inputs can come from bundleio, as optional input file(s), or + // everything hardcoded to ones. + std::vector inputs_storage; + std::vector> input_buffers; + if (!bundle_io) { + if (!FLAGS_inputs.empty()) { + ET_LOG(Info, "Loading inputs from input file(s)."); + std::stringstream list_of_input_files(FLAGS_inputs); + std::string path; + + std::vector file_paths; + while (std::getline(list_of_input_files, path, ',')) { + file_paths.push_back(std::move(path)); + } + // First reserve number of elements to avoid vector reallocations. + inputs_storage.reserve(file_paths.size()); + + for (const auto& file_path : file_paths) { + std::ifstream input_file_handle( + file_path, std::ios::binary | std::ios::ate); + + if (!input_file_handle) { + ET_LOG(Error, "Failed to open input file: %s\n", file_path.c_str()); + return 1; + } + + std::streamsize file_size = input_file_handle.tellg(); + input_file_handle.seekg(0, std::ios::beg); + + // Reserve memory for actual file contents. + inputs_storage.emplace_back(file_size, '\0'); + + if (!input_file_handle.read(inputs_storage.back().data(), file_size)) { + ET_LOG(Error, "Failed to read input file: %s\n", file_path.c_str()); + return 1; + } + + input_buffers.emplace_back(&inputs_storage.back()[0], file_size); + } + } + } + std::unique_ptr ptd_loader; std::unique_ptr ptd_data_map; if (!FLAGS_data_path.empty()) { + ET_LOG(Info, "Loading tensor data from .ptd file."); const char* data_path = FLAGS_data_path.c_str(); Result ptd_loader_result = FileDataLoader::from(data_path); ET_CHECK_MSG( @@ -210,51 +313,41 @@ int main(int argc, char** argv) { static_cast(ptd_data_map->get_num_keys().get())); } - std::vector inputs_storage; - std::vector> input_buffers; - - std::stringstream list_of_input_files(FLAGS_inputs); - std::string path; - - // First reserve memory for number of vector elements to avoid vector - // reallocations when emplacing back. - std::vector file_paths; - while (std::getline(list_of_input_files, path, ',')) { - file_paths.push_back(std::move(path)); - } - inputs_storage.reserve(file_paths.size()); - - for (const auto& file_path : file_paths) { - std::ifstream input_file_handle( - file_path, std::ios::binary | std::ios::ate); - - if (!input_file_handle) { - ET_LOG(Error, "Failed to open input file: %s\n", file_path.c_str()); - return 1; - } - - std::streamsize file_size = input_file_handle.tellg(); - input_file_handle.seekg(0, std::ios::beg); - - // Reserve memory for actual file contents. - inputs_storage.emplace_back(file_size, '\0'); - - if (!input_file_handle.read(&inputs_storage.back()[0], file_size)) { - ET_LOG(Error, "Failed to read input file: %s\n", file_path.c_str()); - return 1; - } + // Create a loader to get the data of the program file. There are other + // DataLoaders that use mmap() or point to data that's already in memory, and + // users can create their own DataLoaders to load from arbitrary sources. + std::unique_ptr loader; - input_buffers.emplace_back(&inputs_storage.back()[0], file_size); + if (bundle_io) { + Result buffer_loader = + BufferDataLoader(program_data, program_data_len); + ET_CHECK_MSG( + buffer_loader.ok(), + "BufferDataLoader failed: 0x%" PRIx32, + static_cast(buffer_loader.error())); + ET_LOG( + Debug, + "Bundled IO PTE Model data loaded. Size: %zu bytes.", + program_data_len); + loader = std::make_unique(std::move(buffer_loader.get())); + } else { + Result file_loader = + FileDataLoader::from(FLAGS_model_path.c_str()); + ET_CHECK_MSG( + file_loader.ok(), + "FileDataLoader::from() failed: 0x%" PRIx32, + static_cast(file_loader.error())); + loader = std::make_unique(std::move(file_loader.get())); } // Parse the program file. This is immutable, and can also be reused between // multiple execution invocations across multiple threads. - Result program = Program::load(&loader.get()); + Result program = Program::load(loader.get()); if (!program.ok()) { - ET_LOG(Error, "Failed to parse model file %s", model_path); + ET_LOG(Error, "Failed to parse model file %s", FLAGS_model_path.c_str()); return 1; } - ET_LOG(Info, "Model file %s is loaded.", model_path); + ET_LOG(Info, "Model file %s is loaded.", FLAGS_model_path.c_str()); // Use the first method in the program. const char* method_name = nullptr; @@ -347,9 +440,8 @@ int main(int argc, char** argv) { et_timestamp_t time_spent_executing = 0; // Run the model. for (uint32_t i = 0; i < FLAGS_num_executions; i++) { - ET_LOG(Debug, "Preparing inputs."); // Allocate input tensors and set all of their elements to 1 or to the - // contents of input_buffers if available. The `inputs` + // contents of input_buffers if available. For non bundled IO, the `inputs` // variable owns the allocated memory and must live past the last call to // `execute()`. // @@ -357,13 +449,30 @@ int main(int argc, char** argv) { // because inputs whose space gets reused by memory planning (if // any such inputs exist) will not be preserved for the next // execution. - auto inputs = executorch::extension::prepare_input_tensors( - *method, {}, input_buffers); - ET_CHECK_MSG( - inputs.ok(), - "Could not prepare inputs: 0x%" PRIx32, - (uint32_t)inputs.error()); - ET_LOG(Debug, "Inputs prepared."); + std::optional inputs; + +#ifdef ET_BUNDLE_IO_ENABLED + if (bundle_io) { + ET_LOG(Debug, "Getting inputs from bundled IO"); + Error status = executorch::bundled_program::load_bundled_input( + *method, model_pte, testset_idx); + ET_CHECK_MSG( + status == Error::Ok, + "load_bundled_input failed with status 0x%" PRIx32, + static_cast(status)); + } else +#endif + { + ET_LOG(Debug, "Preparing inputs."); + auto res = executorch::extension::prepare_input_tensors( + *method, {}, input_buffers); + ET_CHECK_MSG( + res.ok(), + "Could not prepare inputs: 0x%" PRIx32, + (uint32_t)res.error()); + inputs.emplace(std::move(res.get())); + ET_LOG(Debug, "Inputs prepared."); + } const et_timestamp_t before_execute = executorch::runtime::pal_current_ticks(); @@ -375,7 +484,7 @@ int main(int argc, char** argv) { status == Error::Ok, "Execution of method %s failed with status 0x%" PRIx32, method_name, - (uint32_t)status); + static_cast(status)); } const auto tick_ratio = et_pal_ticks_to_ns_multiplier(); constexpr auto NANOSECONDS_PER_MILLISECOND = 1000000; @@ -460,5 +569,58 @@ int main(int argc, char** argv) { ET_CHECK_MSG(status == Error::Ok, "Failed to save ETDump file."); } +#ifdef ET_BUNDLE_IO_ENABLED + if (bundle_io) { + // With bundled io we can check the result. + bool model_ok = false; + + ErrorStats stats = + compute_method_output_error_stats(*method, model_pte, testset_idx); + + if (stats.status == Error::Ok) { + ET_LOG(Info, "=== Error stats for testset %zu ===", testset_idx); + ET_LOG(Info, " mean_absolute_error: %f", stats.mean_abs_error); + ET_LOG(Info, " max_absolute_error: %f", stats.max_abs_error); + ET_LOG(Info, " mean_relative_error: %f", stats.mean_relative_error); + ET_LOG(Info, " max_relative_error: %f", stats.max_relative_error); + } else { + ET_LOG( + Info, + "=== Error calculating stats for testset %zu ERROR: 0x%x" PRIx32 + "===", + testset_idx, + static_cast(stats.status)); + } + + Error status = verify_method_outputs( + *method, + model_pte, + testset_idx, + FLAGS_bundleio_rtol, + FLAGS_bundleio_atol); + if (status == Error::Ok) { + ET_LOG(Info, "Model output match expected BundleIO bpte ref data."); + ET_LOG(Info, "TEST: BundleIO index[%zu] Test_result: PASS", testset_idx); + model_ok = true; + } else { + ET_LOG( + Error, + "Model output don't match expected BundleIO bpte ref data. rtol=%f atol=%f", + FLAGS_bundleio_rtol, + FLAGS_bundleio_atol); + ET_LOG(Error, "TEST: BundleIO index[%zu] Test_result: FAIL", testset_idx); + ET_LOG( + Error, + "Bundle verification failed with status 0x%" PRIx32, + static_cast(status)); + model_ok = false; + } + + if (!model_ok) { + return 1; + } + } +#endif + return 0; } diff --git a/examples/portable/executor_runner/targets.bzl b/examples/portable/executor_runner/targets.bzl index d1304a84bcb..61a4db43f68 100644 --- a/examples/portable/executor_runner/targets.bzl +++ b/examples/portable/executor_runner/targets.bzl @@ -18,6 +18,7 @@ def define_common_targets(): "//executorch/runtime/executor:program", "//executorch/devtools/etdump:etdump_flatcc", "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/data_loader:buffer_data_loader", "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/flat_tensor:flat_tensor_data_map", "//executorch/extension/runner_util:inputs", @@ -38,6 +39,7 @@ def define_common_targets(): deps = [ "//executorch/runtime/executor:program", "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/data_loader:buffer_data_loader", "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/flat_tensor:flat_tensor_data_map", "//executorch/extension/runner_util:inputs", diff --git a/examples/qualcomm/oss_scripts/albert.py b/examples/qualcomm/oss_scripts/albert.py index 3be48215ac6..d529e5db734 100644 --- a/examples/qualcomm/oss_scripts/albert.py +++ b/examples/qualcomm/oss_scripts/albert.py @@ -30,6 +30,9 @@ def main(args): + if args.compile_only and args.pre_gen_pte: + raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) os.makedirs(args.artifact, exist_ok=True) @@ -60,26 +63,32 @@ def main(args): module = AutoModelForMaskedLM.from_pretrained(model_name, config=config).eval() pte_filename = "albert_qnn_q16" - # lower to QNN - passes_job = get_capture_program_passes() - build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - quant_dtype=QuantDtype.use_16a16w, - passes_job=passes_job, - shared_buffer=args.shared_buffer, - ) + # Skip lowering/compilation if using pre-generated PTE + if not args.pre_gen_pte: + # lower to QNN + passes_job = get_capture_program_passes() + build_executorch_binary( + module, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_16a16w, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) if args.compile_only: return workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" - pte_path = f"{args.artifact}/{pte_filename}.pte" + pte_path = ( + f"{args.pre_gen_pte}/{pte_filename}.pte" + if args.pre_gen_pte + else f"{args.artifact}/{pte_filename}.pte" + ) adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), diff --git a/examples/qualcomm/oss_scripts/bert.py b/examples/qualcomm/oss_scripts/bert.py index 0f9255cefdb..aa41df6ff4d 100644 --- a/examples/qualcomm/oss_scripts/bert.py +++ b/examples/qualcomm/oss_scripts/bert.py @@ -30,6 +30,9 @@ def main(args): + if args.compile_only and args.pre_gen_pte: + raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) os.makedirs(args.artifact, exist_ok=True) @@ -57,26 +60,32 @@ def main(args): ).eval() pte_filename = "bert_qnn_q16" - # lower to QNN - passes_job = get_capture_program_passes() - build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - quant_dtype=QuantDtype.use_16a8w, - passes_job=passes_job, - shared_buffer=args.shared_buffer, - ) + # Skip lowering/compilation if using pre-generated PTE + if not args.pre_gen_pte: + # lower to QNN + passes_job = get_capture_program_passes() + build_executorch_binary( + module, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_16a8w, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) if args.compile_only: return workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" - pte_path = f"{args.artifact}/{pte_filename}.pte" + pte_path = ( + f"{args.pre_gen_pte}/{pte_filename}.pte" + if args.pre_gen_pte + else f"{args.artifact}/{pte_filename}.pte" + ) adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), diff --git a/examples/qualcomm/oss_scripts/distilbert.py b/examples/qualcomm/oss_scripts/distilbert.py index 7ca05181645..ce88f61ca5c 100644 --- a/examples/qualcomm/oss_scripts/distilbert.py +++ b/examples/qualcomm/oss_scripts/distilbert.py @@ -31,6 +31,9 @@ def main(args): + if args.compile_only and args.pre_gen_pte: + raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) os.makedirs(args.artifact, exist_ok=True) @@ -58,26 +61,32 @@ def main(args): ).eval() pte_filename = "distilbert_qnn_q16" - # lower to QNN - passes_job = get_capture_program_passes() - build_executorch_binary( - module, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - quant_dtype=QuantDtype.use_16a8w, - passes_job=passes_job, - shared_buffer=args.shared_buffer, - ) + # Skip lowering/compilation if using pre-generated PTE + if not args.pre_gen_pte: + # lower to QNN + passes_job = get_capture_program_passes() + build_executorch_binary( + module, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_16a8w, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) if args.compile_only: return workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" - pte_path = f"{args.artifact}/{pte_filename}.pte" + pte_path = ( + f"{args.pre_gen_pte}/{pte_filename}.pte" + if args.pre_gen_pte + else f"{args.artifact}/{pte_filename}.pte" + ) adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), diff --git a/examples/qualcomm/oss_scripts/eurobert.py b/examples/qualcomm/oss_scripts/eurobert.py index a856616bcf2..5e133aed0d1 100644 --- a/examples/qualcomm/oss_scripts/eurobert.py +++ b/examples/qualcomm/oss_scripts/eurobert.py @@ -35,6 +35,9 @@ def main(args): + if args.compile_only and args.pre_gen_pte: + raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") + assert ( transformers.__version__ >= TRANSFORMERS_VERSION ), f"Please ensure transformers version >= {TRANSFORMERS_VERSION}, current version is {transformers.__version__}" @@ -88,33 +91,40 @@ def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module): pte_filename = "eurobert_qnn_q16" - # lower to QNN - passes_job = get_capture_program_passes() - quantizer = make_quantizer( - quant_dtype=QuantDtype.use_16a16w, - ) - quantizer.add_custom_quant_annotations((annotate_eurobert,)) - with torch.no_grad(): - build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - dataset=inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - custom_quantizer=quantizer, - passes_job=passes_job, - shared_buffer=args.shared_buffer, + # Skip lowering/compilation if using pre-generated PTE + if not args.pre_gen_pte: + # lower to QNN + passes_job = get_capture_program_passes() + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_16a16w, ) + quantizer.add_custom_quant_annotations((annotate_eurobert,)) + with torch.no_grad(): + build_executorch_binary( + model, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + custom_quantizer=quantizer, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) if args.compile_only: return + pte_path = ( + f"{args.pre_gen_pte}/{pte_filename}.pte" + if args.pre_gen_pte + else f"{args.artifact}/{pte_filename}.pte" + ) adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), build_path=f"{args.build_folder}", - pte_path=f"{args.artifact}/{pte_filename}.pte", + pte_path=pte_path, workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, host_id=args.host, diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index e6fa9a66e26..7a08cbfd881 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -2,17 +2,20 @@ ## Overview This file provides you the instructions to run LLM Decoder model with different parameters via Qualcomm HTP backend. We currently support the following models: + 1. LLAMA2 Stories 110M - 2. LLAMA3.2 1B - 3. LLAMA3.2 3B - 4. Codegen2 1B - 5. Gemma 2B - 6. Gemma3 1B - 7. Phi4-mini-instruct - 8. QWEN2.5 0.5B / 1.5B - 9. QWEN3 0.6B / 1.7B - 10. SmolLM2 135M - 11. SmolLM3 3B + 1. LLAMA3.2 1B + 1. LLAMA3.2 3B + 1. Codegen2 1B + 1. Gemma 2B + 1. Gemma3 1B + 1. GLM 1.5B + 1. Granite3.3 2B + 1. Phi4-mini-instruct + 1. QWEN2.5 0.5B / 1.5B + 1. QWEN3 0.6B / 1.7B + 1. SmolLM2 135M + 1. SmolLM3 3B We offer the following modes to execute the model: @@ -63,7 +66,10 @@ Follow the [instructions](https://www.llama.com/) to download models. At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`. -### Step3: Run default examples using hybrid mode for smaller models and kv mode for larger models. +### Step3: Run default examples. +#### Note: +All example scripts below use hybrid mode, which is optimized for on-device performance. However, compiling a model in hybrid mode can consume a significant amount of memory on the host machine—sometimes up to ~100 GB. If your host machine has limited memory, it is highly recommended to switch from `--model_mode hybrid` to `--model_mode kv` and remove the `--prefill_ar_len` flag. + #### LLAMA2 ```bash python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time" @@ -78,7 +84,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### LLAMA3.2 3B Instruct Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` #### Codegen2 @@ -100,10 +106,22 @@ Default example using hybrid mode python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` +#### GLM 1.5B +Default example using hybrid mode +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model glm-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +``` + +#### Granite3.3 2B +Default example using hybrid mode +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model granite_3_3-2b_instruct --prompt "I would like to learn python, could you teach me with a simple example?" --run_lm_eval --task hellaswag --limit 10 +``` + #### Phi4-mini-instruct Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` #### QWEN2.5 0.5B @@ -115,7 +133,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### QWEN2.5 1.5B Default example using kv mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` #### QWEN3 0.6B @@ -127,7 +145,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### QWEN3 1.7B Default example using hybrid mode ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` #### SmolLM2 @@ -139,7 +157,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### SmolLM3 Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` ### KV Cache update mechanism @@ -227,24 +245,24 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL #### Perplexity Evaluation This script supports perplexity evaluation and is capable of assessing perplexity scores across 3 phases: prepare_pt2e(CPU FP), convert_pt2e(CPU QDQ), QNN on device. -To evaluate the perplexity across all 3 phases, users should provide the `--eval_perplexity` flag and specify the evaluation task. Please notice when this flag is provided, the `--prompt ${PROMPT}` will be ignored. +To evaluate the perplexity across all 3 phases, users should provide the `--run_lm_eval` flag and specify the evaluation task. Please notice when this flag is provided, the `--prompt ${PROMPT}` will be ignored. For example, using the Qwen model and 1 wikitext sample as the evaluation task, users can assess all 3 phases perplexity score in a single run by including the appropriate configuration: ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --run_lm_eval --tasks wikitext --limit 1 ``` For the example script above, 1 wikitext sample is used to evaluate all 3 phases. However, there are cases where a user may want to use one sample for quantization calibration and multiple samples for perplexity evaluation. In this case, the process should be split into two runs. In the 1st run, the model is compiled using one sample. In the 2nd run, the user can provide a different configuration for QNN device execution. Example: ```bash # 1st run to compile with --limit 1 -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1 --compile_only +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --run_lm_eval --tasks wikitext --limit 1 --compile_only ``` ```bash # 2nd run to perform QNN device execution with --limit 3 -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --run_lm_eval --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json ``` #### Tasks quantization calibration If `--tasks ${TASK}` is not provided, the program will use `--prompt ${PROMPT}` as the dataset for quantization calibration. -Regardless of whether `--eval_perplexity` is provided, as long as `--tasks ${TASK}` is specified, the specified tasks will be used for model quantization calibration instead of the prompt. +Regardless of whether `--run_lm_eval` is provided, as long as `--tasks ${TASK}` is specified, the specified tasks will be used for model quantization calibration instead of the prompt. diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index e2407e6812a..4e7c4b9be46 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -9,26 +9,18 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Callable, Dict, Tuple, Type - -import torch -from executorch.backends.qualcomm.quantizer.custom_annotation import ( - annotate_down_proj, - annotate_kv_8bit, - annotate_output_16a8w, - annotate_qkv_proj_sha, - StaticLLMQuantConfig, -) -from executorch.backends.qualcomm.quantizer.qconfig import ( - get_ptq_per_channel_quant_config, -) -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from typing import Callable, Dict, Type + from executorch.examples.models.codegen import ( convert_weights as convert_codegen_weights, ) - from executorch.examples.models.gemma import convert_weights as convert_gemma_weights from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights + +from executorch.examples.models.glm import convert_weights as convert_glm_weights +from executorch.examples.models.granite import ( + convert_weights as convert_granite_weights, +) from executorch.examples.models.phi_4_mini import ( convert_weights as convert_phi_4_mini_weights, ) @@ -49,8 +41,27 @@ from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( MultiScopeAwareLlamaModel, ) + +from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import ( + CodegenQuantRecipe, + Gemma3QuantRecipe, + Gemma_2BQuantRecipe, + GLM_1_5B_InstructQuantRecipe, + Granite_3_3_2B_InstructQuantRecipe, + Llama3_1BQuantRecipe, + Llama3_3BQuantRecipe, + LlamaStories110MQuantRecipe, + LlamaStories260KQuantRecipe, + Phi4MiniQuantRecipe, + Qwen2_5_0_5BQuantRecipe, + Qwen2_5_1_5BQuantRecipe, + Qwen3_0_6BQuantRecipe, + Qwen3_1_7BQuantRecipe, + Smollm2QuantRecipe, + Smollm3QuantRecipe, + StaticLLMQuantRecipe, +) from tabulate import tabulate -from torchao.quantization.pt2e import MinMaxObserver BASE_DIR = os.path.dirname(__file__) @@ -59,15 +70,6 @@ LLM_VARIANT_ARCHS = { "gemma3-1b": MultiScopeAwareLlamaModel, } -annotate_wqkv_sha = partial( - annotate_qkv_proj_sha, - qkv_tags={ - StaticLLMQuantConfig.wq_sha, - StaticLLMQuantConfig.wk_sha, - StaticLLMQuantConfig.wv_sha, - }, -) -annotate_wv_sha = partial(annotate_qkv_proj_sha, qkv_tags={StaticLLMQuantConfig.wv_sha}) @dataclass(init=False, frozen=True) @@ -83,8 +85,6 @@ class LLMModelConfig(ABC): transform_weight: Set to true to change Hugging Face weight to improve the performance of RoPE in HTP backend. instruct_model: True if the model uses chat templates. Check Hugging Face model card to ensure the model uses chat templates. num_sharding: Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers. - ptq: Set to true to perform PTQ quantization. Support 16a16w, 16a8w, 16a4w, 16a4w_block, 8a8w. - group_size: Group size used in block quantization for weight quantization. Will only be used when ptq = 16a4w_block masked_softmax: The MaskedSoftmax feature is designed to optimize the LLMs accuracy and performance executed on HTP backend. MaskedSoftmax is used to replace the Softmax(Add(In, Mask)) structure in attention block in LLMs during backend optimization. For more details, please refer to QNN documents. Note that it is only supported starting from QNN 2.35. @@ -93,7 +93,7 @@ class LLMModelConfig(ABC): r1: Enable SpinQuant R1 quantization optimization. r2: Enable SpinQuant R2 quantization optimization. r3: Enable SpinQuant R3 quantization optimization. - custom_annotation: Custom annotation to use when setting quant configs for the model. + quant_recipe: Quantization recipe to use when setting quant configs for the model. """ repo_id: str @@ -104,14 +104,12 @@ class LLMModelConfig(ABC): transform_weight: bool instruct_model: bool num_sharding: int - ptq: QuantDtype - group_size: int masked_softmax: bool seq_mse_candidates: int r1: bool r2: bool r3: bool - custom_annotation: Tuple + quant_recipe: StaticLLMQuantRecipe def __str__(self): # noqa: C901 """ @@ -157,22 +155,6 @@ def format_value(v): table = [(k, v) for k, v in attrs.items()] return tabulate(table, headers=["Config", "Value"], tablefmt="grid") - def get_kv_io_bit_width(self) -> int: - if self.ptq is None: - return 32 - elif ( - self.ptq == QuantDtype.use_8a8w - or annotate_kv_8bit in self.custom_annotation - ): - return 8 - else: - # If quantized but not 8a8w or mix_quantization, it has to be 16bit kv io. - return 16 - - def get_logits_output_bit_width(self) -> int: - # We use 16bit logits for all quant config - return 32 if self.ptq is None else 16 - SUPPORTED_LLM_MODELS: Dict[str, LLMModelConfig] = {} @@ -194,27 +176,13 @@ class LlamaStories260K(LLMModelConfig): convert_weights = None transform_weight = True instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w - group_size = None masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), - ) + quant_recipe = LlamaStories260KQuantRecipe @register_llm_model("stories110m") @@ -225,27 +193,13 @@ class LlamaStories110M(LLMModelConfig): convert_weights = None transform_weight = True instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w - group_size = None masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), - ) + quant_recipe = LlamaStories110MQuantRecipe @register_llm_model("llama3_2-1b_instruct") @@ -257,26 +211,13 @@ class Llama3_2_1B_Instruct(LLMModelConfig): transform_weight = True # The Llama3_2 enabled should be instruct, however, Llama's tokenizer does not provide utility to apply chat template. instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = False seq_mse_candidates = 1000 r1 = False r2 = False r3 = False - quantization_config_down_proj_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial( - annotate_down_proj, quantization_config=quantization_config_down_proj_16a8w - ), - ) + quant_recipe = Llama3_1BQuantRecipe @register_llm_model("llama3_2-3b_instruct") @@ -288,20 +229,32 @@ class Llama3_2_3B_Instruct(LLMModelConfig): transform_weight = True # The Llama3_2 enabled should be instruct, however, Llama's tokenizer does not provide utility to apply chat template. instruct_model = False - num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, + quant_recipe = Llama3_3BQuantRecipe + + +@register_llm_model("codegen2_1b") +@dataclass(init=False, frozen=True) +class Codegen(LLMModelConfig): + repo_id: str = "Salesforce/codegen2-1B_P" + params_path: str = os.path.join( + BASE_DIR, "../../../models/codegen/config/config.json" ) + convert_weights = convert_codegen_weights + transform_weight = True + instruct_model = False + num_sharding = 1 + masked_softmax = True + seq_mse_candidates = 0 + r1 = False + r2 = False + r3 = False + quant_recipe = CodegenQuantRecipe @register_llm_model("gemma-2b") @@ -316,73 +269,70 @@ class Gemma_2B(LLMModelConfig): instruct_model = True num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 64 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w), + quant_recipe = Gemma_2BQuantRecipe + + +@register_llm_model("gemma3-1b") +@dataclass(init=False, frozen=True) +class Gemma3(LLMModelConfig): + repo_id: str = "google/gemma-3-1b-it" + params_path: str = os.path.join( + BASE_DIR, "../../../models/gemma3/config/1b_config.json" ) + convert_weights = convert_gemma3_weights + transform_weight = False + instruct_model = True + num_sharding = 1 + masked_softmax = True + seq_mse_candidates = 0 + r1 = False + r2 = False + r3 = False + quant_recipe = Gemma3QuantRecipe -@register_llm_model("codegen2_1b") +@register_llm_model("glm-1_5b") @dataclass(init=False, frozen=True) -class Codegen(LLMModelConfig): - repo_id: str = "Salesforce/codegen2-1B_P" +class GLM_1_5B(LLMModelConfig): + repo_id: str = "THUDM/glm-edge-1.5b-chat" params_path: str = os.path.join( - BASE_DIR, "../../../models/codegen/config/config.json" + BASE_DIR, "../../../models/glm/config/1_5b_config.json" ) - convert_weights = convert_codegen_weights + convert_weights = convert_glm_weights transform_weight = True - instruct_model = False + instruct_model = True num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a8w - group_size = None - masked_softmax = True + group_size = 32 + masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - custom_annotation = () + quant_recipe = GLM_1_5B_InstructQuantRecipe -@register_llm_model("gemma3-1b") +@register_llm_model("granite_3_3-2b_instruct") @dataclass(init=False, frozen=True) -class Gemma3(LLMModelConfig): - repo_id: str = "google/gemma-3-1b-it" +class Granite_3_3_2b_Instruct(LLMModelConfig): + repo_id: str = "ibm-granite/granite-3.3-2b-instruct" params_path: str = os.path.join( - BASE_DIR, "../../../models/gemma3/config/1b_config.json" + BASE_DIR, "../../../models/granite/config/2b_config.json" ) - convert_weights = convert_gemma3_weights + convert_weights = convert_granite_weights transform_weight = False instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 64 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w), - ) + quant_recipe = Granite_3_3_2B_InstructQuantRecipe @register_llm_model("phi_4_mini") @@ -395,27 +345,13 @@ class Phi4Mini(LLMModelConfig): convert_weights = convert_phi_4_mini_weights transform_weight = False instruct_model = True - num_sharding = 8 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), - ) + quant_recipe = Phi4MiniQuantRecipe @register_llm_model("qwen2_5-0_5b") @@ -428,17 +364,13 @@ class Qwen2_5_0_5B(LLMModelConfig): convert_weights = convert_qwen2_5_weights transform_weight = False instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = True - custom_annotation = () + quant_recipe = Qwen2_5_0_5BQuantRecipe @register_llm_model("qwen2_5-1_5b") @@ -451,17 +383,13 @@ class Qwen2_5_1_5B(LLMModelConfig): convert_weights = convert_qwen2_5_weights transform_weight = False instruct_model = False - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = True - custom_annotation = (annotate_output_16a8w,) + quant_recipe = Qwen2_5_1_5BQuantRecipe @register_llm_model("qwen3-0_6b") @@ -474,24 +402,13 @@ class Qwen3_0_6B(LLMModelConfig): convert_weights = convert_qwen3_weights transform_weight = False instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = True seq_mse_candidates = 1000 r1 = False r2 = False r3 = False - quantization_config_down_proj_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - partial( - annotate_down_proj, quantization_config=quantization_config_down_proj_16a8w - ), - ) + quant_recipe = Qwen3_0_6BQuantRecipe @register_llm_model("qwen3-1_7b") @@ -504,20 +421,13 @@ class Qwen3_1_7B(LLMModelConfig): convert_weights = convert_qwen3_weights transform_weight = False instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = True - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - ) + quant_recipe = Qwen3_1_7BQuantRecipe @register_llm_model("smollm2_135m") @@ -530,17 +440,13 @@ class Smollm2_135M(LLMModelConfig): convert_weights = convert_smollm2_weights transform_weight = True instruct_model = True - num_sharding = 1 - # quant config - ptq = QuantDtype.use_16a8w - group_size = None masked_softmax = False seq_mse_candidates = 0 r1 = False r2 = False r3 = False - custom_annotation = () + quant_recipe = Smollm2QuantRecipe @register_llm_model("smollm3-3b") @@ -551,23 +457,10 @@ class Smollm3_3B(LLMModelConfig): convert_weights = convert_smollm3_weights transform_weight = False instruct_model = True - num_sharding = 4 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 32 masked_softmax = True seq_mse_candidates = 0 r1 = False r2 = False r3 = False - quantization_config_wqkv_sha_16a8w = get_ptq_per_channel_quant_config( - torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial( - annotate_wqkv_sha, quantization_config=quantization_config_wqkv_sha_16a8w - ), - ) + quant_recipe = Smollm3QuantRecipe diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py index c7e7c0cb944..f6f0dc3067f 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py @@ -16,6 +16,7 @@ "stories110m": "llama2", "gemma-2b": "gemma", "gemma3-1b": "gemma3", + "granite_3_3-2b_instruct": "granite", "phi_4_mini": "phi_4_mini", "llama3_2-1b_instruct": "llama3", "llama3_2-3b_instruct": "llama3", @@ -26,4 +27,5 @@ "smollm2_135m": "smollm2_135m", "smollm3-3b": "smollm3", "codegen2_1b": "codegen", + "glm-1_5b": "glm", } diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 085e2a6c07e..d41d9d32120 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -276,7 +276,7 @@ def __init__( # noqa: C901 with open(pte_path, "rb") as f: program_data = f.read() - program = deserialize_pte_binary(program_data) + program = deserialize_pte_binary(program_data).program # Retrieve vocab_size from get_metadata under static_llama that is passed to edge manager self.output_vocab_size = None @@ -867,7 +867,7 @@ def graph_module_inference( num_fewshot=num_fewshot, limit=tasks_limit, ) - logging.info(f"Perplexity evaluation summary for {event_name}") + logging.info(f"Evaluation summary for {event_name}") for task, res in eval_results["results"].items(): logging.info(f"{task}: {res}") diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py index 9af9cdf9549..a21c45c2017 100644 --- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py +++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py @@ -316,7 +316,9 @@ def eval_llm(args): if args.ptq is not None: quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") decoder_model_config = SUPPORTED_LLM_MODELS[args.decoder_model] - custom_annotations = decoder_model_config.custom_annotation + custom_annotations = ( + decoder_model_config.quant_recipe().recipe.custom_quant_annotations + ) quantizer = make_custom_quantizer( quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 91d82531654..0847f93d98f 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -85,6 +85,9 @@ set_scales, WrappedLlamaModel, ) +from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import ( + StaticLLMQuantRecipe, +) from executorch.examples.qualcomm.utils import ( make_output_dir, @@ -220,36 +223,20 @@ def quantize( quant_dtype, args, tokenizer, - custom_annotations=(), + quant_recipe, scales_state_dict=None, chat_template=None, lookahead_config=None, ): self.quant_dtype = quant_dtype - quantizer = make_custom_quantizer( - quant_dtype, args.range_setting, custom_annotations - ) + quantizer = make_custom_quantizer(quant_dtype, args.range_setting, ()) self.has_quant_io = True fx_graph_module = None with torch.no_grad(): fx_graph_module = torch.export.export( self.llama_graph_module, self.inputs, strict=True ).module() - - if quant_dtype == QuantDtype.use_16a4w_block: - if self.decoder_model_config.group_size is None: - raise ValueError( - "Group size is required when use quant_dtype 16a4w_block" - ) - conv_nodes = [ - n for n in fx_graph_module.graph.nodes if "conv" in n.name - ] - block_size_map = { - n.name: (1, self.decoder_model_config.group_size, 1, 1) - for n in conv_nodes - } - quantizer.set_block_size_map(block_size_map) - + quantizer.recipe = quant_recipe fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) logging.info("Quantizing the model...") @@ -302,8 +289,8 @@ def quantize( if args.verbose: logging.info("Verifying the QDQ model...") - # qdq cpu ppl evaluation is time consuming, only enable when eval_perplexity - if args.eval_perplexity: + # qdq cpu ppl evaluation is time consuming, only enable when run_lm_eval + if args.run_lm_eval: # Check qdq cpu results graph_module_inference( use_kv_cache=self.llama_meta["get_use_kv_cache"], @@ -439,12 +426,16 @@ def compile( with open(params_path) as f: kv_config = ModelArgs(**json.load(f)) + # get quant recipe + quant_recipe: StaticLLMQuantRecipe = decoder_model_config.quant_recipe(True) + # TODO: support batch inputs if necessary kv_config.max_batch_size = 1 kv_config.max_seq_len = args.max_seq_len kv_config.use_kv_cache = True kv_config.enable_r3 = decoder_model_config.r3 - kv_config.kv_io_bit_width = decoder_model_config.get_kv_io_bit_width() + kv_config.kv_io_bit_width = quant_recipe.get_kv_io_bit_width() + if decoder_model_config.masked_softmax: if is_qnn_sdk_version_less_than("2.35"): logging.warning( @@ -643,7 +634,7 @@ def permute(w, heads, partial_rotary_dim): QuantDtype.use_8a8w: (8, 8), QuantDtype.use_16a4w: (16, 4), QuantDtype.use_16a4w_block: (16, 4), - }[decoder_model_config.ptq] + }[quant_recipe.default_quant_dtype] scales_state_dict = compute_scales( wrapped_model, tokens, weight_bits, act_bits, 1600 ) @@ -661,24 +652,24 @@ def permute(w, heads, partial_rotary_dim): use_fp16 = True # "io_type" here refers to logits output and "kv_type" refers to kv_cache input/output. fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} - if decoder_model_config.ptq: - if decoder_model_config.get_kv_io_bit_width() == 8: + if quant_recipe.default_quant_dtype: + if quant_recipe.get_kv_io_bit_width() == 8: fixed_point_type["kv_type"] = torch.uint8 - elif decoder_model_config.get_kv_io_bit_width() == 16: + elif quant_recipe.get_kv_io_bit_width() == 16: fixed_point_type["kv_type"] = torch.uint16 else: raise RuntimeError( - f"Unknown kv io bit width {decoder_model_config.get_kv_io_bit_width()}" + f"Unknown kv io bit width {quant_recipe.get_kv_io_bit_width()}" ) - if decoder_model_config.get_logits_output_bit_width() == 16: + if quant_recipe.get_logits_output_bit_width() == 16: fixed_point_type["io_type"] = torch.uint16 else: raise RuntimeError( - f"Unknown logits io bit width {decoder_model_config.get_logits_output_bit_width()}" + f"Unknown logits io bit width {quant_recipe.get_logits_output_bit_width()}" ) - quant_dtype = decoder_model_config.ptq + quant_dtype = quant_recipe.default_quant_dtype if args.dtype_override is not None: dtype_override = DType[args.dtype_override] @@ -701,9 +692,8 @@ def permute(w, heads, partial_rotary_dim): QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY ]["skip_node"] = {"tokens"} - if decoder_model_config.ptq: + if quant_recipe.default_quant_dtype: start_quantize_ts = time.time() - custom_annotations = decoder_model_config.custom_annotation kv_quant_attrs = {} for i, llama_instance in enumerate(llama_instance_list): lookahead_config = ( @@ -711,11 +701,12 @@ def permute(w, heads, partial_rotary_dim): if i == 0 and args.model_mode == "lookahead" else None ) + llama_instance.quantize( quant_dtype=quant_dtype, args=args, tokenizer=tokenizer, - custom_annotations=custom_annotations, + quant_recipe=quant_recipe, scales_state_dict=scales_state_dict, chat_template=chat_template, lookahead_config=lookahead_config, @@ -729,11 +720,11 @@ def permute(w, heads, partial_rotary_dim): kv_quant_attrs[output_indices] = output.args[1:] output_indices += 1 break - custom_annotations = custom_annotations + ( + quant_recipe.recipe.custom_quant_annotations.append( partial( annotate_prefill_kv_output, kv_quant_attrs=kv_quant_attrs, - ), + ) ) # temporarily remove annotate_prefill_kv_output llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ @@ -892,7 +883,7 @@ def inference( else f"{args.artifact}/{pte_filename}.pte" ) - if args.eval_perplexity: + if args.run_lm_eval: # Generate the eval wrapper eval_wrapper = QnnRunnerEvalWrapper( args=args, @@ -911,21 +902,41 @@ def inference( ) if args.ip and args.port != -1: - assert ( - len(args.tasks) == 1 and args.tasks[0] == "wikitext" - ), "CI currently supports wikitext only" - wiki_ppl = eval_results["results"][args.tasks[0]]["word_perplexity,none"] - pte_size = os.path.getsize(pte_path) - with Client((args.ip, args.port)) as conn: - conn.send( - json.dumps( - { - "wiki_ppl": wiki_ppl, - "pte_size": pte_size, - "inference_speed": eval_wrapper.inference_speed, - } + assert len(args.tasks) == 1, "CI currently supports 1 lm_eval task only." + match args.tasks[0]: + case "wikitext": + wiki_ppl = eval_results["results"][args.tasks[0]][ + "word_perplexity,none" + ] + pte_size = os.path.getsize(pte_path) + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps( + { + "wiki_ppl": wiki_ppl, + "pte_size": pte_size, + "inference_speed": eval_wrapper.inference_speed, + } + ) + ) + case "hellaswag": + acc_norm = eval_results["results"][args.tasks[0]]["acc_norm,none"] + pte_size = os.path.getsize(pte_path) + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps( + { + "acc_norm": acc_norm, + "pte_size": pte_size, + "inference_speed": eval_wrapper.inference_speed, + } + ) + ) + case _: + raise RuntimeError( + "CI currently supports [wikitext, hellaswag] only." ) - ) + else: for task, res in eval_results["results"].items(): logging.info(f"{task}: {res}") @@ -1023,7 +1034,8 @@ def post_process(): runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner", ) # No pregen inputs, input_list is not required - adb.push(inputs=[], files=[runtime_tokenizer_path]) + if not args.skip_push: + adb.push(inputs=[], files=[runtime_tokenizer_path]) adb.execute(custom_runner_cmd=runner_cmd) adb.pull(output_path=args.artifact, callback=post_process) @@ -1052,7 +1064,7 @@ def post_process(): def _build_tasks_parser(parser): parser.add_argument( - "--eval_perplexity", + "--run_lm_eval", help="If enabled, this will use the tasks provided under args.tasks to calibrate the model", action="store_true", default=False, @@ -1139,7 +1151,7 @@ def _build_parser(): parser.add_argument( "--system_prompt", - help="For Llama3. Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None", + help="For Llama3/Granite. Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None", default="", type=str, ) @@ -1233,9 +1245,9 @@ def _build_parser(): def export_llama(args) -> None: if args.compile_only and args.pre_gen_pte: raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") - if args.eval_perplexity and args.model_mode != "kv": + if args.run_lm_eval and args.model_mode != "kv": raise RuntimeError("Eval device perplexity is only supported for KV mode") - if args.eval_perplexity and args.tasks is None: + if args.run_lm_eval and args.tasks is None: raise RuntimeError("Please provide --tasks to eval perplexity") assert ( args.decoder_model in SUPPORTED_LLM_MODELS @@ -1297,15 +1309,23 @@ def export_llama(args) -> None: # For Gemma, use tokenizer.model as it doesn't provide pre_tokenizer in tokenizer.json. runtime_tokenizer_path = tokenizer_artifacts[-3] else: + if args.decoder_model == "glm-1_5b": + with open(tokenizer_config, "r+") as file: + data = json.load(file) + # Verified with HF flow and it uses <|user|> as eos condition + data["bos_token"] = "<|user|>" + data["eos_token"] = "<|user|>" + file.seek(0) + json.dump(data, file, indent=4) + file.truncate() runtime_tokenizer_path = tokenizer_artifacts[-1] + tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config) if args.decoder_model == "codegen2_1b": # Override the default BOS and EOS token IDs for codegen2_1b tokenizer.bos_id = 1 tokenizer.eos_id = 2 - - # TODO: Remove this once error is resolved. elif args.decoder_model == "phi_4_mini": with open(runtime_tokenizer_path, "r+") as file: data = json.load(file) diff --git a/examples/qualcomm/oss_scripts/llama/model/feed_forward.py b/examples/qualcomm/oss_scripts/llama/model/feed_forward.py index 062123b52cc..2f36779cc71 100644 --- a/examples/qualcomm/oss_scripts/llama/model/feed_forward.py +++ b/examples/qualcomm/oss_scripts/llama/model/feed_forward.py @@ -88,3 +88,54 @@ def forward(self, x): hidden_states = self.act(hidden_states) hidden_states = self.fc_out(hidden_states) return hidden_states + + +@register_feed_forward("GlmForCausalLM") +class GLMFeedForward(FeedForwardBase): + """FeedForward with gate_up_proj and down_proj""" + + def __init__(self, args: ModelArgs): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + + assert args.hidden_dim is not None + self.dim = args.dim + self.hidden_dim = args.hidden_dim + + self.gate_up_proj = torch.nn.Linear(args.dim, 2 * args.hidden_dim, bias=False) + self.down_proj = torch.nn.Linear(args.hidden_dim, args.dim, bias=False) + self.activation_fn = args.act_fn.get_function() + + def prepare_feedfoward_conv(self): + self.gate_up_proj_conv = torch.nn.Conv2d( + self.dim, 2 * self.hidden_dim, 1, bias=False + ) + self.down_proj_conv = torch.nn.Conv2d(self.hidden_dim, self.dim, 1, bias=False) + + self.forward_no_conv = self.forward + self.forward = self.forward_feedfoward_conv + + self.gate_up_proj_conv.weight.data.copy_( + self.gate_up_proj.weight[:, :, None, None] + ) + self.down_proj_conv.weight.data.copy_(self.down_proj.weight[:, :, None, None]) + + del self.gate_up_proj + del self.down_proj + + def forward_feedfoward_conv(self, x): + bsz, _, _ = x.size() + x = torch.reshape(x, (bsz, -1, 1, self.dim)) + x = x.transpose(1, 3) # Transpose right before and after Conv + up_states = self.gate_up_proj_conv(x) + gate, up_states = up_states.chunk(2, dim=1) + up_states = up_states * self.activation_fn(gate) + x = self.down_proj_conv(up_states) + x = x.transpose(1, 3) + x = torch.reshape(x, (bsz, -1, self.dim)) + return x + + def forward(self, x): + up_states = self.gate_up_proj(x) + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + return self.down_proj(up_states) diff --git a/examples/qualcomm/oss_scripts/llama/model/layernorm.py b/examples/qualcomm/oss_scripts/llama/model/layernorm.py index a6c12920ed8..7db14bdfd01 100644 --- a/examples/qualcomm/oss_scripts/llama/model/layernorm.py +++ b/examples/qualcomm/oss_scripts/llama/model/layernorm.py @@ -42,6 +42,7 @@ def __init__(self, hidden_size: int, eps=1e-5): super().__init__(hidden_size, eps=eps) +@register_norm("gemma3") @register_norm("rmsnorm") class RMSNorm(torch.nn.RMSNorm, Norm): def __init__(self, hidden_size: int, eps=1e-5): diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index ba2d33d7890..65cf71e0480 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -81,7 +81,11 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals self.attn_softmax = torch.nn.Softmax(dim=-1) - self.scale = float(self.head_dim) ** 0.5 + self.scale = ( + float(self.head_dim) ** 0.5 + if config.attention_multiplier is None + else 1.0 / config.attention_multiplier + ) if getattr(config, "enable_r3", False): self.register_buffer( @@ -349,7 +353,6 @@ def prepare_feedfoward_conv(self): self.forward_no_conv = self.forward self.forward = self.forward_feedfoward_conv - self.w1_conv.weight.data.copy_(self.w1.weight[:, :, None, None]) self.w2_conv.weight.data.copy_(self.w2.weight[:, :, None, None]) self.w3_conv.weight.data.copy_(self.w3.weight[:, :, None, None]) @@ -398,6 +401,7 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals if config.post_attention_norm else None ) + self.residual_multiplier = config.residual_multiplier self.post_ffn_norm = ( torch.nn.RMSNorm(config.dim, eps=config.norm_eps) if config.post_ffn_norm @@ -425,12 +429,20 @@ def forward( ) if self.post_attention_norm: h = self.post_attention_norm(h) - h = x + h + h = ( + x + h * self.residual_multiplier + if self.residual_multiplier is not None + else x + h + ) hidden_states = hidden_states if self.ffn_norm is None else self.ffn_norm(h) out = self.feed_forward(hidden_states) if self.post_ffn_norm: out = self.post_ffn_norm(out) - output = h + out + output = ( + h + out * self.residual_multiplier + if self.residual_multiplier is not None + else h + out + ) return output, k_cache, v_cache @@ -462,6 +474,7 @@ def __init__( self.use_i64_token = use_i64_token self.output_cache = output_cache self.kv_io_bit_width = config.kv_io_bit_width + self.logits_scaling = config.logits_scaling self.layers = nn.ModuleList( [ @@ -549,6 +562,9 @@ def forward( hidden_states = self.norm(hidden_states) logits = self.output(hidden_states) + if self.logits_scaling: + logits = logits / self.logits_scaling + if self.output_cache: return logits, output_k_cache, output_v_cache return logits diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 52796e886fd..af260242316 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -10,8 +10,8 @@ * @file * * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma3 1B, - * phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, SmolLM2 135M, - * SmolLM3 3B with Qualcomm AI Engine Direct. + * Granite3.3 2B, phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, + * SmolLM2 135M, SmolLM3 3B with Qualcomm AI Engine Direct. * */ @@ -130,6 +130,17 @@ std::string get_formatted_prompt( formatted_prompt.append("\n"); } break; + case example::DecoderModelVersion::kGranite: + if (!system_prompt.empty()) { + formatted_prompt.append("<|start_of_role|>system<|end_of_role|>"); + formatted_prompt.append(system_prompt); + formatted_prompt.append("<|end_of_text|>\n"); + } + formatted_prompt.append("<|start_of_role|>user<|end_of_role|>"); + formatted_prompt.append(prompt); + formatted_prompt.append("<|end_of_text|>\n"); + formatted_prompt.append("<|start_of_role|>assistant<|end_of_role|>"); + break; case example::DecoderModelVersion::kPhi4: if (!system_prompt.empty()) { formatted_prompt.append("<|system|>"); @@ -172,6 +183,15 @@ std::string get_formatted_prompt( formatted_prompt.append("<|im_end|>\n"); formatted_prompt.append("<|im_start|>assistant\n"); break; + case example::DecoderModelVersion::kGlm: + formatted_prompt.append("<|user|>\n"); + formatted_prompt.append(prompt); + if (!system_prompt.empty()) { + formatted_prompt.append("<|system|>\n"); + formatted_prompt.append(system_prompt); + } + formatted_prompt.append("<|assistant|>\n"); + break; default: ET_CHECK_MSG(false, "unsupported llama version"); break; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index e239a2a5fe1..e021d5d512f 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -127,6 +127,8 @@ Runner::Runner( } else if (decoder_model_version == "gemma3") { decoder_model_version_ = DecoderModelVersion::kGemma3; cache_mode_ = CacheMode::HybridCache; + } else if (decoder_model_version == "granite") { + decoder_model_version_ = DecoderModelVersion::kGranite; } else if (decoder_model_version == "phi_4_mini") { decoder_model_version_ = DecoderModelVersion::kPhi4; } else if (decoder_model_version == "qwen2_5") { @@ -139,6 +141,8 @@ Runner::Runner( decoder_model_version_ = DecoderModelVersion::kSmollm3; } else if (decoder_model_version == "codegen") { decoder_model_version_ = DecoderModelVersion::kCodegen; + } else if (decoder_model_version == "glm") { + decoder_model_version_ = DecoderModelVersion::kGlm; } else { ET_CHECK_MSG(false, "Unsupported Decoder Model"); } @@ -209,6 +213,8 @@ Error Runner::load() { eos_ids->insert(tokenizer_->encode("", 0, 0).get()[0]); } else if (decoder_model_version_ == DecoderModelVersion::kCodegen) { eos_ids->insert(tokenizer_->encode("<|endoftext|>", 0, 0).get()[0]); + } else if (decoder_model_version_ == DecoderModelVersion::kGlm) { + eos_ids->insert(tokenizer_->encode("<|user|>", 0, 0).get()[0]); } // Try avoid getMetadataHelper as it is time consuming. @@ -376,7 +382,22 @@ Error Runner::generate_from_prompt_or_file( stats_.inference_start_ms = time_in_ms(); int32_t seq_len = config.seq_len; - seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_; + if (seq_len > context_len_) { + ET_LOG( + Info, + "Warning: Requested seq_len (%d) exceeds compiled max_seq_len (%d). Clamping to %d.", + seq_len, + context_len_, + context_len_); + seq_len = context_len_; + } else if (seq_len <= 0) { + ET_LOG( + Info, + "Warning: Invalid seq_len (%d). Using compiled max_seq_len (%d).", + seq_len, + context_len_); + seq_len = context_len_; + } int32_t n_bos = (cur_pos_ == 0) ? 1 : 0; // encode the (string) prompt into tokens sequence diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 9cf730c3620..c436d40f20c 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -34,12 +34,14 @@ enum DecoderModelVersion { kLlama3, kGemma, kGemma3, + kGranite, kPhi4, kQwen2_5, kQwen3, kSmollm2_135m, kSmollm3, kCodegen, + kGlm, }; enum KvBitWidth { diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp index 6775c08bd87..40e8fb1a82d 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp @@ -323,6 +323,30 @@ Result TokenGenerator::generate( break; } } + + // Check if generation was truncated due to seq_len limit (no EOS token) + if (eos_ids_->count(cur_token) == 0 && pos >= seq_len - 1) { + printf("\n"); + ET_LOG( + Info, + "Warning: Generation stopped at seq_len limit (%d) without reaching EOS token. Response may be incomplete.", + seq_len); + if (seq_len >= metadata_.context_len) { + ET_LOG( + Info, + "- seq_len (%d) already equals compiled max_seq_len (%d). Consider recompiling with larger --max_seq_len.", + seq_len, + metadata_.context_len); + } else { + ET_LOG( + Info, + "- seq_len (%d) is less than compiled max_seq_len (%d). Consider increasing --seq_len (up to %d).", + seq_len, + metadata_.context_len, + metadata_.context_len); + } + } + return pos - start_pos; } // Explicit instantiations diff --git a/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py new file mode 100644 index 00000000000..1736a44e642 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py @@ -0,0 +1,624 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_kv_8bit +from executorch.backends.qualcomm.quantizer.quant_recipe import ( + QuantGranularity, + QuantRecipe, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from torchao.quantization.pt2e import MinMaxObserver + + +class StaticLLMQuantRecipe: + """ + Qualcomm's static LLaMA quantization recipe. + """ + + def __init__(self): + self.recipe: Optional[QuantRecipe] = None + + # For IO bitwidth + self.default_quant_dtype = getattr(self, "default_quant_dtype", None) + if self.default_quant_dtype is None: + raise ValueError("default_quant_dtype must be defined in the recipe.") + + def annotate(self, graph_module: torch.fx.GraphModule): + self.recipe.annotate(graph_module) + + def get_kv_io_bit_width(self) -> int: + if self.default_quant_dtype is None: + return 32 + elif ( + self.default_quant_dtype == QuantDtype.use_8a8w + or annotate_kv_8bit in self.recipe.custom_quant_annotations + ): + return 8 + else: + # If quantized but not 8a8w or mix_quantization, it has to be 16bit kv io. + return 16 + + def get_logits_output_bit_width(self) -> int: + # We use 16bit logits for all quant config + return 32 if self.default_quant_dtype is None else 16 + + +class LlamaStories260KQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"layers\..*\.attention\.wv.*"}, + QuantDtype.use_8a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class LlamaStories110MQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"layers\..*\.attention\.wv.*"}, + QuantDtype.use_8a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Llama3_1BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w_block + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + note="default with 16bit activation", + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + note="Annotate with 16a4w block quantization since these layers are not sensitive.", + ) + .add_regex( + { + r"output\.conv", + r"layers\.[0-3]\.feed_forward\.w2_conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + note="Down proj layer is sensitive and should be annotated with 16a8w.", + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Llama3_3BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w_block + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + ) + .add_regex( + { + r"output\.conv", + r"layers\.2[1-7]\.feed_forward\.w2_conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class CodegenQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a8w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ).add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + + +class Gemma_2BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 64, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wv.*", + r"output\.conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Gemma3QuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 64, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wv.*", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class GLM_1_5B_InstructQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Granite_3_3_2B_InstructQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 64, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wv.*", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Phi4MiniQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + .add_regex( + {r"layers\..*\.attention\.wv.*"}, + QuantDtype.use_8a4w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Qwen2_5_0_5BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ).add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + + +class Qwen2_5_1_5BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + .add_regex( + {r"output\.conv"}, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + + +class Qwen3_0_6BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.feed_forward\.w2_conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + + +class Qwen3_1_7BQuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 16, 1, 1)}, + ) + .add_regex( + { + r"output\.conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) + + +class Smollm2QuantRecipe(StaticLLMQuantRecipe): + default_quant_dtype = QuantDtype.use_16a8w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ).add_node_target( + { + torch.ops.aten.conv2d.default, + }, + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + + +class Smollm3QuantRecipe(StaticLLMQuantRecipe): + + default_quant_dtype = QuantDtype.use_16a4w + + def __init__(self, verbose: bool = False): + super().__init__() + + self.recipe = ( + QuantRecipe( + self.default_quant_dtype, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_TENSOR, + verbose=verbose, + ) + .add_node_target( + { + torch.ops.aten.conv2d.default, + }, + QuantDtype.use_16a4w_block, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_BLOCK, + extra_kwargs={"block_size": (1, 32, 1, 1)}, + ) + .add_regex( + { + r"layers\..*\.attention\.wq.*", + r"layers\..*\.attention\.wk.*", + r"layers\..*\.attention\.wv.*", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + .add_regex( + { + r"output\.conv", + }, + QuantDtype.use_16a8w, + False, + act_observer=MinMaxObserver, + granularity=QuantGranularity.PER_CHANNEL, + ) + ) + self.recipe.custom_quant_annotations.append(annotate_kv_8bit) diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/runner.h b/examples/qualcomm/qaihub_scripts/llama/runner/runner.h index 9672d6a3586..215930392ba 100644 --- a/examples/qualcomm/qaihub_scripts/llama/runner/runner.h +++ b/examples/qualcomm/qaihub_scripts/llama/runner/runner.h @@ -55,7 +55,7 @@ class Runner { // inference_end_ms: End of inference/generation. long inference_end_ms; // Keep a running total of the time spent in sampling. - long aggregate_sampling_time_ms; + long aggregate_sampling_time_ms = 0; // Token count from prompt int64_t num_prompt_tokens; // Token count from generated (total - prompt) diff --git a/examples/qualcomm/scripts/torchvision_vit.py b/examples/qualcomm/scripts/torchvision_vit.py index 2a428683ec3..ed8dbb792c4 100755 --- a/examples/qualcomm/scripts/torchvision_vit.py +++ b/examples/qualcomm/scripts/torchvision_vit.py @@ -7,12 +7,14 @@ import json import logging import os +from contextlib import contextmanager from multiprocessing.connection import Client import numpy as np import torch +import torch.nn.functional as F from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel from executorch.examples.qualcomm.utils import ( @@ -25,6 +27,56 @@ ) +# Copied from torch/nn/functional.py +# QNN does not have 5D permute optimization. Fuse to a single 4D optimization +# Changed unsqueeze(0).transpose(0, -2).squeeze(-2) to permute(2, 0, 1, 3) +def _in_projection_packed_custom(q, k, v, w, b=None) -> list[torch.Tensor]: + from torch.nn.functional import linear + + E = q.size(-1) + if k is v: + if q is k: + # self-attention + proj = linear(q, w, b) + # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + proj = proj.unflatten(-1, (3, E)).permute(2, 0, 1, 3).contiguous() + # pyrefly: ignore # bad-return + return proj[0], proj[1], proj[2] + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + q_proj = linear(q, w_q, b_q) + kv_proj = linear(k, w_kv, b_kv) + # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + kv_proj = kv_proj.unflatten(-1, (2, E)).permute(2, 0, 1, 3).contiguous() + # pyrefly: ignore # bad-return + return (q_proj, kv_proj[0], kv_proj[1]) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + # pyrefly: ignore # bad-return + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +# Context manager to patch temporarily, so it won't affect other users using F._in_projection_packed +@contextmanager +def PermuteInProjectionPacked(): + # Save the original function so it can be restored later + _original_in_projection_packed = F._in_projection_packed + F._in_projection_packed = _in_projection_packed_custom + try: + yield + finally: + F._in_projection_packed = _original_in_projection_packed + + def main(args): # ensure the working directory exist. os.makedirs(args.artifact, exist_ok=True) @@ -44,16 +96,18 @@ def main(args): ) pte_filename = "vit_qnn_q8" - instance = TorchVisionViTModel() - build_executorch_binary( - instance.get_eager_model().eval(), - instance.get_example_inputs(), - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - quant_dtype=QuantDtype.use_8a8w, - shared_buffer=args.shared_buffer, - ) + instance = TorchVisionViTModel().get_eager_model().eval() + + with PermuteInProjectionPacked(): + build_executorch_binary( + instance, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=QuantDtype.use_8a8w, + shared_buffer=args.shared_buffer, + ) if args.compile_only: return diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py index dace37e5473..58f2ccf1001 100644 --- a/examples/vulkan/export.py +++ b/examples/vulkan/export.py @@ -10,29 +10,29 @@ import argparse import logging +import os -import backends.vulkan.test.utils as test_utils - +import executorch.backends.vulkan.test.utils as test_utils import torch import torchvision - from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner from executorch.devtools import BundledProgram from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.devtools.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) +from executorch.examples.models import MODEL_NAME_TO_MODEL +from executorch.examples.models.model_factory import EagerModelFactory from executorch.exir import to_edge_transform_and_lower from executorch.extension.export_util.utils import save_pte_program from executorch.extension.pytree import tree_flatten from torch.export import Dim, export -from ..models import MODEL_NAME_TO_MODEL -from ..models.model_factory import EagerModelFactory - FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) +import urllib + def is_vision_model(model_name): if model_name in [ @@ -70,6 +70,38 @@ def get_vision_model_dynamic_shapes(): ) +def get_dog_image_tensor(image_size=224, normalization="imagenet"): + url, filename = ( + "https://github.com/pytorch/hub/raw/master/images/dog.jpg", + "dog.jpg", + ) + try: + urllib.URLopener().retrieve(url, filename) + except: + urllib.request.urlretrieve(url, filename) + + from PIL import Image + from torchvision import transforms + + input_image = Image.open(filename).convert("RGB") + + transforms_list = [ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + ] + if normalization == "imagenet": + transforms_list.append( + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ) + + preprocess = transforms.Compose(transforms_list) + + input_tensor = preprocess(input_image) + input_batch = input_tensor.unsqueeze(0) + input_batch = (input_batch,) + return input_batch + + def init_model(model_name): if model_name == "convnext_small": return torchvision.models.convnext_small() @@ -77,13 +109,29 @@ def init_model(model_name): return torchvision.models.densenet161() if model_name == "shufflenet_v2_x1_0": return torchvision.models.shufflenet_v2_x1_0() + if model_name == "YOLO_NAS_S": + try: + from super_gradients.common.object_names import Models + from super_gradients.training import models + except ImportError: + raise ImportError( + "Please install super-gradients to use the YOLO_NAS_S model." + ) + + return models.get(Models.YOLO_NAS_S, pretrained_weights="coco") return None def get_sample_inputs(model_name): + # Lock the random seed for reproducibility + torch.manual_seed(42) + if is_vision_model(model_name): return get_vision_model_sample_input() + if model_name == "YOLO_NAS_S": + input_batch = get_dog_image_tensor(640) + return input_batch return None @@ -95,7 +143,7 @@ def get_dynamic_shapes(model_name): return None -def main() -> None: +def main() -> None: # noqa: C901 logger = logging.getLogger("") logger.setLevel(logging.INFO) @@ -117,6 +165,24 @@ def main() -> None: "False", ) + parser.add_argument( + "--small_texture_limits", + action=argparse.BooleanOptionalAction, + default=False, + help="sets the default texture limit to be (2048, 2048, 2048) which is " + "compatible with more devices (i.e. desktop/laptop GPUs) compared to the " + "default (16384, 16384, 2048) which is more targeted for mobile GPUs. Default " + "is False.", + ) + + parser.add_argument( + "--skip_memory_planning", + action=argparse.BooleanOptionalAction, + default=False, + help="Skips memory planning pass while lowering, which can be used for " + "debugging. Default is False.", + ) + parser.add_argument( "-s", "--strict", @@ -159,6 +225,13 @@ def main() -> None: help="Execute lower_module_and_test_output to validate the model. Default is False", ) + parser.add_argument( + "--save_inputs", + action=argparse.BooleanOptionalAction, + default=False, + help="Whether to save the inputs to the model. Default is False", + ) + args = parser.parse_args() if args.model_name in MODEL_NAME_TO_MODEL: @@ -189,6 +262,10 @@ def main() -> None: if args.force_fp16: compile_options["force_fp16"] = True + if args.skip_memory_planning: + compile_options["skip_memory_planning"] = True + if args.small_texture_limits: + compile_options["small_texture_limits"] = True logging.info(f"Exporting model {args.model_name} with Vulkan delegate") @@ -230,25 +307,18 @@ def main() -> None: atol = 2e-2 rtol = 1e-1 - # Test the model if --test flag is provided - if args.test: - test_result = test_utils.run_and_check_output( - reference_model=model, - executorch_program=exec_prog, - sample_inputs=example_inputs, - atol=atol, - rtol=rtol, - ) + # Save regular program + save_pte_program(exec_prog, output_filename, args.output_dir) + logging.info( + f"Model exported and saved as {output_filename}.pte in {args.output_dir}" + ) - if test_result: - logging.info( - "✓ Model test PASSED - outputs match reference within tolerance" - ) - else: - logging.error("✗ Model test FAILED - outputs do not match reference") - raise RuntimeError( - "Model validation failed: ExecuTorch outputs do not match reference model outputs" - ) + if args.save_inputs: + inputs_flattened, _ = tree_flatten(example_inputs) + for i, input_tensor in enumerate(inputs_flattened): + input_filename = os.path.join(args.output_dir, f"input{i}.bin") + input_tensor.numpy().tofile(input_filename) + f"Model input saved as {input_filename} in {args.output_dir}" if args.bundled: # Create bundled program @@ -287,13 +357,27 @@ def main() -> None: logging.info( f"Bundled program exported and saved as {output_filename}.bpte in {args.output_dir}" ) - else: - # Save regular program - save_pte_program(exec_prog, output_filename, args.output_dir) - logging.info( - f"Model exported and saved as {output_filename}.pte in {args.output_dir}" + + # Test the model if --test flag is provided + if args.test: + test_result = test_utils.run_and_check_output( + reference_model=model, + executorch_program=exec_prog, + sample_inputs=example_inputs, + atol=atol, + rtol=rtol, ) + if test_result: + logging.info( + "✓ Model test PASSED - outputs match reference within tolerance" + ) + else: + logging.error("✗ Model test FAILED - outputs do not match reference") + raise RuntimeError( + "Model validation failed: ExecuTorch outputs do not match reference model outputs" + ) + if __name__ == "__main__": with torch.no_grad(): diff --git a/exir/_serialize/__init__.py b/exir/_serialize/__init__.py index 5a5ec315b7f..242f254ca46 100644 --- a/exir/_serialize/__init__.py +++ b/exir/_serialize/__init__.py @@ -8,6 +8,7 @@ from executorch.exir._serialize._program import ( deserialize_pte_binary as _deserialize_pte_binary, + PTEFile as _PTEFile, serialize_pte_binary as _serialize_pte_binary, ) @@ -15,4 +16,5 @@ __all__ = [ "_deserialize_pte_binary", "_serialize_pte_binary", + "_PTEFile", ] diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index bee5b3438b0..be7bf0bd56f 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -21,7 +21,10 @@ _program_flatbuffer_to_json, _program_json_to_flatbuffer, ) -from executorch.exir._serialize._named_data_store import NamedDataStoreOutput +from executorch.exir._serialize._named_data_store import ( + NamedDataStore, + NamedDataStoreOutput, +) from executorch.exir._serialize.data_serializer import DataEntry @@ -46,6 +49,19 @@ _HEADER_BYTEORDER: Literal["little"] = "little" +@dataclass +class PTEFile: + """ + Wraps together the data required to serialize into a PTE file. + """ + + program: Program + # TODO(lfq): add constant data (currently restored in the program) + # TODO(lfq): update this to List[bytes] + mutable_data: Optional[List[Buffer]] = None + named_data: Optional[NamedDataStoreOutput] = None + + @dataclass class AlignedData: """ @@ -403,19 +419,17 @@ def _extract_named_data( def serialize_pte_binary( - program: Program, + pte_file: PTEFile, *, - mutable_data: Optional[List[Buffer]] = None, extract_delegate_segments: bool = False, segment_alignment: int = 128, constant_tensor_alignment: Optional[int] = None, delegate_alignment: Optional[int] = None, - named_data: Optional[NamedDataStoreOutput] = None, ) -> Cord: """Returns the runtime binary representation of the given Program. Args: - program: The Program to serialize. + pte_file: PTEFile class containing the program and segments. extract_delegate_segments: Whether to move delegate data blobs from the Program into separate segments, rather than encoding those blobs in the flatbuffer data. When true, will also: @@ -430,8 +444,6 @@ def serialize_pte_binary( delegate_alignment: If provided, the minimum alignment of delegate data in the program. Must be a power of 2. If not provided, uses the value in the schema file. - named_data: If provided, named blobs to be stored in segments - after the PTE file. Returns: The serialized form of the Program, ready for execution by the runtime. """ @@ -442,7 +454,7 @@ def serialize_pte_binary( # Don't modify the original program. # TODO(T144120904): Could avoid yet more huge copies with a more shallow # copy, reusing the actual data blobs. - program = copy.deepcopy(program) + program = copy.deepcopy(pte_file.program) # Store extracted segment data, with any buffer-specific alignment. # This may be constant data, delegate data or named data. @@ -466,9 +478,9 @@ def serialize_pte_binary( # Add to the aggregate segments cord. segments.append(AlignedData(constant_segment_data)) - if mutable_data is not None: + if pte_file.mutable_data is not None: mutable_segment_data, mutable_segment_offsets = _extract_constant_segment( - mutable_data, + pte_file.mutable_data, tensor_alignment=None, # data is copied at Method load so no need to align. ) if len(mutable_segment_data) > 0: @@ -483,8 +495,10 @@ def serialize_pte_binary( if extract_delegate_segments: _extract_delegate_segments(program, segments) - if named_data is not None: - _extract_named_data(program, segments, named_data.buffers, named_data.pte_data) + if pte_file.named_data is not None: + _extract_named_data( + program, segments, pte_file.named_data.buffers, pte_file.named_data.pte_data + ) # Append all segments into a single Cord, adding any necessary padding to ensure that # each segment begins at the required alignment. @@ -575,7 +589,91 @@ def serialize_pte_binary( return pte_data -def _restore_segments(program: Program, segment_data: bytes) -> Program: +def _restore_delegates(program: Program, segments: List[bytes]) -> Program: + """Find and replace the Program's references to these segments, inlining + the data. + + Args: + program: The Program holding non-inlined delegates. Modified in-place. + segments: List of bytes containing the delegate data. Not modified. + + Returns: The Program with delegates restored. + """ + for plan_index, plan in enumerate(program.execution_plan): + for delegate_index, delegate in enumerate(plan.delegates): + if delegate.processed.location == DataLocation.INLINE: + continue + assert delegate.processed.location == DataLocation.SEGMENT + index = delegate.processed.index + if index >= len(segments): + raise ValueError( + f"Plan {plan_index} delegate {delegate_index} " + + f"segment index {index} >= num segments {len(segments)}" + ) + + data_index: int = len(program.backend_delegate_data) + program.backend_delegate_data.append( + BackendDelegateInlineData(data=segments[index]) + ) + delegate.processed = BackendDelegateDataReference( + location=DataLocation.INLINE, index=data_index + ) + return program + + +def _restore_constant_segment( + constant_segment: SubsegmentOffsets, segment_data: bytes +) -> List[Buffer]: + """Convert constant and mutable tensors from a single byte-blob into a list of individual tensors. + + Args: + constant_segment: SubsegmentOffset with the offsets of each tensor. + segment_data: byte data containing the tensors and padding. Not modified. + + Returns: + List[Buffer] containing each tensor in a separate object. + """ + buffers: List[Buffer] = [] + for i in range(len(constant_segment.offsets)): + start_offset = constant_segment.offsets[i] + # Note: this is the original end offset plus any padding between it and the next start offset + end_offset = ( + constant_segment.offsets[i + 1] + if i < len(constant_segment.offsets) - 1 + else len(segment_data) + ) + buffers.append(Buffer(storage=segment_data[start_offset:end_offset])) + return buffers + + +def _restore_named_data( + program: Program, + segments: List[bytes], +) -> NamedDataStoreOutput: + """Moves named data from `segments` and `program` into the + NamedDataStoreOutput class. + + Args: + program: The Program holding named data references. Not modified. + segments: The data containing the segments. Not modified. + """ + named_data_store = NamedDataStore() + for entry in program.named_data: + if entry.segment_index >= len(segments): + raise ValueError( + "Named data segment index " + f"{entry.segment_index} >= num segments {len(segments)}" + ) + named_data_store.add_named_data( + key=entry.key, + data=segments[entry.segment_index], + alignment=1, # Deserialization does not preserve alignment. + tensor_layout=None, # PTE file currently does not serialize this. + ) + return named_data_store.get_named_data_store_output() + + +def _restore_segments(program: Program, segment_data: bytes) -> PTEFile: """Moves segments from `segment_data` into `program`. This should recreate the original Program that the segments were extracted @@ -589,7 +687,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: the preceding data has been stripped off so that the first segment begins at offset zero. Returns: - The Program with segments restored. + PTEFile, containing the Program with delegate and constant segments restored, mutable data segment, and named data segment. """ # Extract the list of segment data blobs, which parallel program.segments. segments: List[bytes] = [] @@ -600,53 +698,51 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: ) segments.append(segment_data[segment.offset : segment.offset + segment.size]) - # Find and replace the Program's references to these segments, inlining the - # data. - for plan_index, plan in enumerate(program.execution_plan): - for delegate_index, delegate in enumerate(plan.delegates): - if delegate.processed.location == DataLocation.INLINE: - continue - assert delegate.processed.location == DataLocation.SEGMENT - index = delegate.processed.index - if index >= len(segments): - raise ValueError( - f"Plan {plan_index} delegate {delegate_index} " - + f"segment index {index} >= num segments {len(segments)}" - ) - - data_index: int = len(program.backend_delegate_data) - program.backend_delegate_data.append( - BackendDelegateInlineData(data=segments[index]) - ) - delegate.processed = BackendDelegateDataReference( - location=DataLocation.INLINE, index=data_index - ) + # Restore delegate segments that weren't inlined previously. + program = _restore_delegates(program, segments) # Replace constants from constant_segment into constant_buffer. if program.constant_segment and len(program.constant_segment.offsets) > 0: - buffers: List[Buffer] = [] - constant_segment = segments[program.constant_segment.segment_index] - for i in range(len(program.constant_segment.offsets)): - start_offset = program.constant_segment.offsets[i] - # Note: this is the original end offset plus any padding between - # it and the next start offset. - end_offset = ( - program.constant_segment.offsets[i + 1] - if i < len(program.constant_segment.offsets) - 1 - else len(constant_segment) + if program.constant_segment.segment_index >= len(segments): + raise ValueError( + f"Constant segment index {program.constant_segment.segment_index} >= num segments {len(segments)}" ) - buffers.append(Buffer(storage=constant_segment[start_offset:end_offset])) - program.constant_buffer = buffers + program.constant_buffer = _restore_constant_segment( + program.constant_segment, segments[program.constant_segment.segment_index] + ) program.constant_segment.segment_index = 0 program.constant_segment.offsets = [] - # Clear out the segments list since the original Program didn't have one. + # Extract mutable segments. + mutable_data = None + if program.mutable_data_segments and len(program.mutable_data_segments) > 0: + if len(program.mutable_data_segments) > 1: + raise ValueError("Can't handle more than 1 mutable data segment.") + segment_index = program.mutable_data_segments[0].segment_index + if segment_index >= len(segments): + raise ValueError( + f"Mutable data segment index {segment_index} >= num segments {len(segments)}" + ) + mutable_data = _restore_constant_segment( + program.mutable_data_segments[0], + segments[segment_index], + ) + program.mutable_data_segments = None + + # Extract named data. + named_data = None + if program.named_data: + named_data = _restore_named_data(program, segments) + + # Clear named_data and segments, which are empty pre-serialization. + program.named_data = [] program.segments = [] - return program + return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data) -def deserialize_pte_binary(program_data: bytes) -> Program: - """Returns a Program deserialized from the given runtime binary data.""" + +def deserialize_pte_binary(program_data: bytes) -> PTEFile: + """Returns a PTEFile deserialized from the given runtime binary data.""" program_size = len(program_data) segment_base_offset = 0 @@ -664,8 +760,8 @@ def deserialize_pte_binary(program_data: bytes) -> Program: if segment_base_offset != 0: # Move segment data back into the Program. - program = _restore_segments( + return _restore_segments( program=program, segment_data=program_data[segment_base_offset:] ) - return program + return PTEFile(program=program, mutable_data=None, named_data=None) diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py index 789ae89b190..60b6079f4a8 100644 --- a/exir/_serialize/_serialize.py +++ b/exir/_serialize/_serialize.py @@ -8,10 +8,10 @@ from typing import Dict, Optional, Set, Tuple -from executorch.exir._serialize import _serialize_pte_binary - from executorch.exir._serialize._cord import Cord from executorch.exir._serialize._named_data_store import NamedDataStoreOutput + +from executorch.exir._serialize._program import PTEFile, serialize_pte_binary from executorch.exir._serialize.data_serializer import ( DataEntry, DataPayload, @@ -46,14 +46,16 @@ def serialize_for_executorch( pte_data=named_data_store.pte_data, external_data={}, ) - pte: Cord = _serialize_pte_binary( - program=emitter_output.program, - mutable_data=emitter_output.mutable_data, + pte: Cord = serialize_pte_binary( + pte_file=PTEFile( + program=emitter_output.program, + mutable_data=emitter_output.mutable_data, + named_data=pte_named_data, + ), extract_delegate_segments=config.extract_delegate_segments, segment_alignment=config.segment_alignment, constant_tensor_alignment=config.constant_tensor_alignment, delegate_alignment=config.delegate_alignment, - named_data=pte_named_data, ) # Serialize PTD files. diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 80f4b8ca49f..46e8f020a0b 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -8,12 +8,13 @@ # pyre-unsafe import copy +import dataclasses import difflib import json import math import unittest -from typing import List, Sequence +from typing import Dict, List, Sequence from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json from executorch.exir._serialize._named_data_store import NamedDataStoreOutput @@ -23,6 +24,7 @@ _json_to_program, _program_to_json, deserialize_pte_binary, + PTEFile, serialize_pte_binary, ) from executorch.exir._serialize.data_serializer import DataEntry @@ -172,7 +174,7 @@ def constant_segment_with_tensor_alignment( # Extract blobs into constant segment during serialization. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=constant_tensor_alignment, ) @@ -281,13 +283,43 @@ def constant_segment_with_tensor_alignment( ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual( + len(deserialized.program.constant_buffer), len(program.constant_buffer) + ) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) + + def _check_named_data_entries( + self, reference: Dict[str, DataEntry], actual: Dict[str, DataEntry] + ) -> None: + self.assertEqual(reference.keys(), actual.keys()) + SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison. + for key in reference.keys(): + ref_entry = reference[key] + actual_entry = actual[key] + for field in dataclasses.fields(ref_entry): + if field.name not in SKIP_FIELDS: + self.assertEqual( + getattr(ref_entry, field.name), + getattr(actual_entry, field.name), + f"Named data record {key}.{field.name} does not match.", + ) + + def _check_named_data_store_output( + self, reference: NamedDataStoreOutput, actual: NamedDataStoreOutput + ) -> None: + # Check buffers. + self.assertEqual(reference.buffers, actual.buffers) + # Check pte_data. + self._check_named_data_entries(reference.pte_data, actual.pte_data) + # Should be empty. + self.assertEqual(reference.external_data, actual.external_data) def test_canonicalize_delegate_indices(self) -> None: def make_execution_plan( @@ -415,7 +447,7 @@ def test_round_trip_no_header_no_segments(self) -> None: deserializing. """ program = get_test_program() - pte_data = bytes(serialize_pte_binary(program)) + pte_data = bytes(serialize_pte_binary(pte_file=PTEFile(program))) self.assertGreater(len(pte_data), 16) # File magic should be present at the expected offset. @@ -426,10 +458,12 @@ def test_round_trip_no_header_no_segments(self) -> None: self.assertIsNone(eh) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) def test_round_trip_large_buffer_sizes(self) -> None: """Tests that when the non_const_buffer_sizes contains integers @@ -438,8 +472,10 @@ def test_round_trip_large_buffer_sizes(self) -> None: """ program = get_test_program() program.execution_plan[0].non_const_buffer_sizes = [0, 2**48] - flatbuffer_from_py = bytes(serialize_pte_binary(program)) - self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py)) + flatbuffer_from_py = bytes(serialize_pte_binary(pte_file=PTEFile(program))) + self.assert_programs_equal( + program, deserialize_pte_binary(flatbuffer_from_py).program + ) def test_round_trip_no_segments_and_no_header(self) -> None: """Tests that a Program serialized with extract_delegate_segments=True @@ -448,7 +484,11 @@ def test_round_trip_no_segments_and_no_header(self) -> None: the same after serializing and deserializing. """ program = get_test_program() - pte_data = bytes(serialize_pte_binary(program, extract_delegate_segments=True)) + pte_data = bytes( + serialize_pte_binary( + pte_file=PTEFile(program), extract_delegate_segments=True + ) + ) self.assertGreater(len(pte_data), 16) # File magic should be present at the expected offset. @@ -463,10 +503,12 @@ def test_round_trip_no_segments_and_no_header(self) -> None: self.assertEqual(program_with_segments.segments, []) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs should be the same. - self.assert_programs_equal(program, program2) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) @staticmethod def gen_blob_data(size: int, pattern: bytes) -> bytes: @@ -496,7 +538,7 @@ def test_round_trip_with_segments(self) -> None: # Extract the blobs into segments during serialization. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, ) @@ -598,8 +640,10 @@ def test_round_trip_with_segments(self) -> None: # meaning that the segments were moved back to inline. This also # demonstrates that the contents of all segments survived, and weren't # truncated or corrupted. - program2 = deserialize_pte_binary(pte_data) - self.assert_programs_equal(program, program2) + deserialized = deserialize_pte_binary(pte_data) + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) def test_no_constants(self) -> None: program = get_test_program() @@ -608,7 +652,7 @@ def test_no_constants(self) -> None: pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, @@ -640,7 +684,7 @@ def test_unused_inline_delegate_blobs_with_segments(self) -> None: # Extract the blobs into segments should succeeed. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, ) @@ -655,7 +699,7 @@ def test_unused_inline_delegate_blobs_with_segments(self) -> None: # Should cause serialization to fail. with self.assertRaises(ValueError): serialize_pte_binary( - program, + PTEFile(program=program), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, ) @@ -676,7 +720,7 @@ def test_constant_segment_tensor_alignment_non_power_of_2_fails(self) -> None: # Expect failure as tensor alignment 14 is not a power of 2. with self.assertRaises(ValueError): serialize_pte_binary( - program, + PTEFile(program=program), segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=constant_tensor_alignment, ) @@ -711,11 +755,10 @@ def test_constant_delegate_and_named_data_segments(self) -> None: # Extract the blobs into segments during serialization. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program, named_data=named_data), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, - named_data=named_data, ) ) @@ -884,13 +927,17 @@ def test_constant_delegate_and_named_data_segments(self) -> None: ) # Convert back. - program2 = deserialize_pte_binary(pte_data) + deserialized = deserialize_pte_binary(pte_data) # Programs are the same besides constant_buffer, as deserialization # does not preserve constant segment; padding may be added # during serialization. - self.assertEqual(program2.execution_plan, program.execution_plan) + self.assertEqual(deserialized.program.execution_plan, program.execution_plan) # Number of constant tensors should be the same. - self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + self.assertEqual( + len(deserialized.program.constant_buffer), len(program.constant_buffer) + ) + self.assertEqual(deserialized.mutable_data, None) + self._check_named_data_store_output(deserialized.named_data, named_data) def test_named_data_segments(self) -> None: # Set segment alignment to 12 to test the padding. @@ -918,11 +965,10 @@ def test_named_data_segments(self) -> None: # Serialize the program with named data segments. pte_data = bytes( serialize_pte_binary( - program, + PTEFile(program=program, named_data=named_data), extract_delegate_segments=True, segment_alignment=SEGMENT_ALIGNMENT, constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, - named_data=named_data, ) ) @@ -995,6 +1041,26 @@ def test_named_data_segments(self) -> None: buffers[2], ) + # Test roundtrip + deserialized = deserialize_pte_binary(pte_data) + self.assert_programs_equal(deserialized.program, program) + self.assertEqual(deserialized.mutable_data, None) + self._check_named_data_store_output(deserialized.named_data, named_data) + + # Test re-serialize + pte_data2 = serialize_pte_binary( + PTEFile(program=deserialized.program, named_data=deserialized.named_data), + extract_delegate_segments=True, + segment_alignment=SEGMENT_ALIGNMENT, + constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, + ) + # pte_data2 is not going to be the same as pte_data due to alignment; + # directly test the deserialized one. + deserialized2 = deserialize_pte_binary(bytes(pte_data2)) + self.assert_programs_equal(deserialized2.program, program) + self.assertEqual(deserialized2.mutable_data, None) + self._check_named_data_store_output(deserialized2.named_data, named_data) + # Common data for extended header tests. The two example values should produce # the example data. diff --git a/exir/backend/backend_details.py b/exir/backend/backend_details.py index 6999dadb9f9..9614826c61a 100644 --- a/exir/backend/backend_details.py +++ b/exir/backend/backend_details.py @@ -57,6 +57,22 @@ class BackendDetails(ABC): """ + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + # Allow direct subclasses of BackendDetails + if cls.__bases__ == (BackendDetails,): + return + + # Forbid subclasses whose ANY parent is already a child of BackendDetails + for base in cls.__bases__: + if issubclass(base, BackendDetails) and base is not BackendDetails: + raise TypeError( + f"ExecuTorch delegate doesn't support nested backend, '{base.__name__}' " + " should be a final backend implementation and should not be subclassed " + f"(attempted by '{cls.__name__}')." + ) + @staticmethod # all backends need to implement this method @enforcedmethod diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS index f5913826c17..cf44a3021d4 100644 --- a/exir/backend/test/TARGETS +++ b/exir/backend/test/TARGETS @@ -95,9 +95,9 @@ python_unittest( ) runtime.python_library( - name = "qnn_backend_demo", + name = "demo_backend", srcs = [ - "qnn_backend_demo.py", + "demo_backend.py", ], visibility = [ "//executorch/...", @@ -119,7 +119,7 @@ runtime.python_library( "//executorch/test/...", ], deps = [ - ":qnn_backend_demo", + ":demo_backend", "//caffe2:torch", "//executorch/exir:lib", "//executorch/exir/backend:partitioner", @@ -153,7 +153,7 @@ runtime.python_library( name = "example_backends", deps = [ ":backend_with_compiler_demo", - ":qnn_backend_demo", + ":demo_backend", ], ) @@ -171,7 +171,7 @@ python_unittest( ":backend_with_compiler_demo", ":hta_partitioner_demo", ":op_partitioner_demo", - ":qnn_backend_demo", + ":demo_backend", "//caffe2:torch", "//caffe2/functorch:functorch_src", "//executorch/exir:delegate", @@ -250,7 +250,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/hypothesis:hypothesis", ":op_partitioner_demo", - ":qnn_backend_demo", + ":demo_backend", "//caffe2:torch", "//executorch/exir:delegate", "//executorch/exir:lib", @@ -273,7 +273,7 @@ python_unittest( ":backend_with_compiler_demo", ":hta_partitioner_demo", ":op_partitioner_demo", - ":qnn_backend_demo", + ":demo_backend", "//caffe2:torch", "//caffe2/functorch:functorch_src", "//executorch/exir:delegate", @@ -303,7 +303,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/hypothesis:hypothesis", ":backend_with_compiler_demo", - ":qnn_backend_demo", + ":demo_backend", "//caffe2:torch", "//executorch/exir:lib", "//executorch/exir:schema", diff --git a/exir/backend/test/qnn_backend_demo.py b/exir/backend/test/demo_backend.py similarity index 96% rename from exir/backend/test/qnn_backend_demo.py rename to exir/backend/test/demo_backend.py index 1823cea79cf..b7575cfe549 100644 --- a/exir/backend/test/qnn_backend_demo.py +++ b/exir/backend/test/demo_backend.py @@ -16,7 +16,7 @@ @final -class QnnBackend(BackendDetails): +class DemoBackend(BackendDetails): @staticmethod def preprocess( edge_program: ExportedProgram, diff --git a/exir/backend/test/hta_partitioner_demo.py b/exir/backend/test/hta_partitioner_demo.py index ec3b0ef3d5d..ba42c50b0f7 100644 --- a/exir/backend/test/hta_partitioner_demo.py +++ b/exir/backend/test/hta_partitioner_demo.py @@ -18,7 +18,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.test.qnn_backend_demo import QnnBackend +from executorch.exir.backend.test.demo_backend import DemoBackend from executorch.exir.backend.utils import tag_constant_data from torch.export import ExportedProgram from torch.fx.passes.infra.partitioner import Partition @@ -28,7 +28,7 @@ class HTAPartitionerMultiplePatternsDemo(Partitioner): """ An example implementation to partition graph for HTA, in this example, the backend - associate with this partitioner is QnnBackend. With QnnBackend, the two lowerable + associate with this partitioner is DemoBackend. With DemoBackend, the two lowerable patterns are: (lstm + conv) and (sub). backend is a class member instead of instance members, as it is a properties of HTAPartitionerMultiplePatternsDemo, and won't be different for different HTAPartitionerMultiplePatternsDemo instances. @@ -116,7 +116,7 @@ def sub(x, y): pattern_sub.graph, ] - backend_id = QnnBackend.__name__ + backend_id = DemoBackend.__name__ self.delegation_spec = DelegationSpec(backend_id, []) def is_exclusive(self, partition_list_list: List[List[Partition]]) -> bool: @@ -269,7 +269,7 @@ def forward(self, x_raw, h, c): ] # Only (lstm + conv) pattern is lowerable - backend_id = QnnBackend.__name__ + backend_id = DemoBackend.__name__ self.delegation_spec = DelegationSpec(backend_id, []) def partition(self, exported_program: ExportedProgram) -> PartitionResult: diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py index 42e8ef16bd7..fa124c855db 100644 --- a/exir/backend/test/test_backends.py +++ b/exir/backend/test/test_backends.py @@ -12,6 +12,7 @@ import torch from executorch.exir import to_edge from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( AllNodePartitioner, ) @@ -26,6 +27,7 @@ from executorch.exir.backend.test.backend_with_compiler_demo import ( BackendWithCompilerDemo, ) +from executorch.exir.backend.test.demo_backend import DemoBackend from executorch.exir.backend.test.hta_partitioner_demo import ( HTAPartitionerMultiplePatternsDemo, HTAPartitionerOnePatternDemo, @@ -34,7 +36,6 @@ AddAttributePartitionerDemo, AddMulPartitionerDemo, ) -from executorch.exir.backend.test.qnn_backend_demo import QnnBackend from executorch.exir.delegate import executorch_call_delegate from executorch.exir.dialects._ops import ops as exir_ops @@ -642,7 +643,7 @@ def forward(self, x_raw, h, c): self.check_backend_delegate( program=program_with_delegates.program, delegate=program_with_delegates.program.execution_plan[0].delegates[0], - expected_id=QnnBackend.__name__, + expected_id=DemoBackend.__name__, expected_processed=b"imqnncompiled", ) @@ -783,7 +784,7 @@ def forward(self, x_raw, h, c): self.check_backend_delegate( program=program_with_delegates.program, delegate=program_with_delegates.program.execution_plan[0].delegates[0], - expected_id=QnnBackend.__name__, + expected_id=DemoBackend.__name__, expected_processed=b"imqnncompiled", ) @@ -1444,3 +1445,19 @@ def inputs(self): self.assertTrue( torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03) ) + + def test_prohibited_nested_backends(self): + class MyBackend(BackendDetails): + @staticmethod + def preprocess(edge_program, compile_specs): + return None + + with self.assertRaises(TypeError) as ctx: + + class MyOtherBackend(MyBackend): + pass + + self.assertIn( + "'MyBackend' should be a final backend implementation and should not be subclassed (attempted by 'MyOtherBackend')", + str(ctx.exception), + ) diff --git a/exir/backend/test/test_backends_lifted.py b/exir/backend/test/test_backends_lifted.py index ef3b502bfd3..8a2bdaa77a5 100644 --- a/exir/backend/test/test_backends_lifted.py +++ b/exir/backend/test/test_backends_lifted.py @@ -25,6 +25,7 @@ from executorch.exir.backend.test.backend_with_compiler_demo import ( BackendWithCompilerDemo, ) +from executorch.exir.backend.test.demo_backend import DemoBackend from executorch.exir.backend.test.hta_partitioner_demo import ( HTAPartitionerMultiplePatternsDemo, HTAPartitionerOnePatternDemo, @@ -33,7 +34,6 @@ AddAttributePartitionerDemo, AddMulPartitionerDemo, ) -from executorch.exir.backend.test.qnn_backend_demo import QnnBackend from executorch.exir.delegate import executorch_call_delegate from executorch.exir.dialects._ops import ops as exir_ops @@ -664,7 +664,7 @@ def forward(self, x_raw, h, c): delegate=program_with_delegates._emitter_output.program.execution_plan[ 0 ].delegates[0], - expected_id=QnnBackend.__name__, + expected_id=DemoBackend.__name__, expected_processed=b"imqnncompiled", ) @@ -803,7 +803,7 @@ def forward(self, x_raw, h, c): delegate=program_with_delegates._emitter_output.program.execution_plan[ 0 ].delegates[0], - expected_id=QnnBackend.__name__, + expected_id=DemoBackend.__name__, expected_processed=b"imqnncompiled", ) diff --git a/exir/backend/test/test_compatibility.py b/exir/backend/test/test_compatibility.py index 4bde3d40b2c..9b6ae79ba97 100644 --- a/exir/backend/test/test_compatibility.py +++ b/exir/backend/test/test_compatibility.py @@ -8,7 +8,7 @@ import torch from executorch.exir import to_edge -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _PTEFile, _serialize_pte_binary from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( AllNodePartitioner, @@ -58,7 +58,7 @@ def forward(self, x): # Generate the .pte file with the wrong version. buff = bytes( _serialize_pte_binary( - program=prog, + pte_file=_PTEFile(program=prog), ) ) @@ -105,7 +105,7 @@ def forward(self, x): # Generate the .pte file with the wrong version. buff = bytes( _serialize_pte_binary( - program=prog, + pte_file=_PTEFile(program=prog), ) ) diff --git a/exir/backend/test/test_debug_handle_map.py b/exir/backend/test/test_debug_handle_map.py index c6d426cf082..a82207239ac 100644 --- a/exir/backend/test/test_debug_handle_map.py +++ b/exir/backend/test/test_debug_handle_map.py @@ -11,8 +11,8 @@ import torch from executorch import exir from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.test.demo_backend import DemoBackend from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo -from executorch.exir.backend.test.qnn_backend_demo import QnnBackend from executorch.exir.delegate import executorch_call_delegate from hypothesis import given, settings, strategies as st @@ -84,10 +84,10 @@ def test_lowered_the_whole_model(self, unlift): edge_compile_config ) lowered_model = to_backend( - QnnBackend.__name__, edgeir_m.exported_program, [] + DemoBackend.__name__, edgeir_m.exported_program, [] ) - # QnnBackend compile all nodes as one node. The debug_handle_map will be like (1: (debug handle from all nodes)) + # DemoBackend compile all nodes as one node. The debug_handle_map will be like (1: (debug handle from all nodes)) # Ensure there is only one debug identifier self.assertEqual( len(lowered_model.meta["debug_handle_map"].keys()), diff --git a/exir/backend/test/test_lowered_backend_module.py b/exir/backend/test/test_lowered_backend_module.py index 6cdaf92b3d2..06a843df17d 100644 --- a/exir/backend/test/test_lowered_backend_module.py +++ b/exir/backend/test/test_lowered_backend_module.py @@ -16,7 +16,7 @@ from executorch.exir.backend.test.backend_with_compiler_demo import ( BackendWithCompilerDemo, ) -from executorch.exir.backend.test.qnn_backend_demo import QnnBackend +from executorch.exir.backend.test.demo_backend import DemoBackend from executorch.exir.schema import DelegateCall, Program from executorch.extension.pybindings.portable_lib import ( # @manual @@ -128,7 +128,7 @@ def test_emit_lowered_backend_module(self): compile_config=edge_compile_config, ) lowered_model = to_backend( - QnnBackend.__name__, edgeir_m.exported_program(), [] + DemoBackend.__name__, edgeir_m.exported_program(), [] ) program = lowered_model.program() reference_program = self.get_program_from_wrapped_module( @@ -181,7 +181,7 @@ def test_emit_nested_lowered_backend_module(self): compile_config=edge_compile_config, ) lowered_module = to_backend( - QnnBackend.__name__, edgeir_m.exported_program(), [] + DemoBackend.__name__, edgeir_m.exported_program(), [] ) # This module will include one operator and two delegate call @@ -200,7 +200,7 @@ def forward(self, *args): ) nested_lowered_model = to_backend( - QnnBackend.__name__, wrapped_module_edge.exported_program(), [] + DemoBackend.__name__, wrapped_module_edge.exported_program(), [] ) program = nested_lowered_model.program() diff --git a/exir/capture/_config.py b/exir/capture/_config.py index b2252e122c9..3fbc8ae7ef3 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -6,7 +6,7 @@ # pyre-unsafe from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch @@ -94,9 +94,14 @@ class ExecutorchBackendConfig: # Moreover, static views will be elided from the ExecuTorch graph remove_view_copy: bool = True - # If set to true, all constant tensors will be stored in a separate file, - # external to the PTE file. - external_constants: bool = False + # Bool: if True, all constant tensors will be stored in a separate file. If False, + # all constant tensors will be stored in the PTE file. + # Callable: a function from torch.fx.Node to Optional[str]. This will be called for each + # placeholder (constant tensor) node, and if it returns a string, that node will be + # tagged with the string. If None, the constant tensor is stored in the PTE file. + # Otherwise, it is stored in a file named by the string. E.g., a function + # lambda x: "model_weights" will save all constants into a file "model_weights.ptd". + external_constants: Union[bool, Callable[[torch.fx.Node], Optional[str]]] = False # If set to true, all trainable weights will be stored in a separate file, # external to the PTE file. diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 199a667ab64..165bc2951f7 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -1665,9 +1665,10 @@ def forward(self, x): self.assertEqual(values[5].val, Double(double_val=float("-inf"))) # Confirm that we can also deserialize the model with infinity in it. - pte_data = deserialize_pte_binary(model.buffer) + deserialize = deserialize_pte_binary(model.buffer) self.assertEqual( - pte_data.execution_plan, model.executorch_program.execution_plan + deserialize.program.execution_plan, + model.executorch_program.execution_plan, ) def test_mutate_input_tensor(self) -> None: @@ -1716,9 +1717,38 @@ def forward(self, x): external_map = emitter_output.external_constant_map[ "_default_external_constant" ] + self.assertEqual(len(external_map), 2) self.assertEqual(external_map["linear.weight"], 0) self.assertEqual(external_map["linear.bias"], 1) + def test_constant_tagged_tensors_custom(self) -> None: + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + return self.linear(x) + + model = to_edge( + export(LinearModule(), (torch.ones(5, 5),), strict=True) + ).to_executorch( + config=ExecutorchBackendConfig( + external_constants=lambda x: ( + "linear_weight" if "weight" in x.name else None + ), + ) + ) + emitter_output = model._emitter_output + # constant_buffer contains placeholder and linear bias. + self.assertEqual(len(emitter_output.program.constant_buffer), 2) + # external constant buffer contains linear weight. + self.assertEqual(len(emitter_output.external_constant_buffer), 1) + # The lambda saves all constants to the key 'linear_weight'. + external_map = emitter_output.external_constant_map["linear_weight"] + self.assertEqual(len(external_map), 1) + self.assertEqual(external_map["linear.weight"], 0) + def test_constant_tagged_tensor_dedup(self) -> None: class ConstantModule(nn.Module): def __init__(self): diff --git a/exir/graph_module.py b/exir/graph_module.py index e26d22d8145..2adf62ab0b8 100644 --- a/exir/graph_module.py +++ b/exir/graph_module.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +11,7 @@ from typing import Callable, Dict, List, Tuple, Union import torch +from torch._ops import HigherOrderOperator LeafValue = Union[ @@ -46,14 +48,15 @@ def _get_submodule( return submod_node.target, submodule, node -def get_control_flow_submodules( +def _get_control_flow_submodules( graph_module: torch.fx.GraphModule, + op_to_submodule_arg_index: dict[HigherOrderOperator, list[int]], ) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: """ Returns a list of submodules used for control flow operations - (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look - into submodules). Specifically, the returned value is a list containing a - tuple of (name of the submodule that's stored in the graph module, the + that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing + tuples of (name of the submodule that's stored in the graph module, the submodule itself, and the fx node that uses this submodule). """ control_flow_submodules = [] @@ -61,15 +64,50 @@ def get_control_flow_submodules( if node.op != "call_function": continue - if node.target is torch.ops.higher_order.cond: - control_flow_submodules.append(_get_submodule(graph_module, node, 1)) - control_flow_submodules.append(_get_submodule(graph_module, node, 2)) - if node.target is torch.ops.higher_order.map_impl: - control_flow_submodules.append(_get_submodule(graph_module, node, 0)) + for op in op_to_submodule_arg_index: + if node.target is not op: + continue + for i in op_to_submodule_arg_index[op]: + control_flow_submodules.append(_get_submodule(graph_module, node, i)) return control_flow_submodules +def get_control_flow_submodules( + graph_module: torch.fx.GraphModule, +) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing + tuples of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). + """ + return _get_control_flow_submodules( + graph_module, + {torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]}, + ) + + +def get_cond_while_submodules( + graph_module: torch.fx.GraphModule, +) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]: + """ + Returns a list of submodules used for control flow operations + (torch.ops.higher_order.cond/while_loop) that are in the given toplevel graph (does not look + into submodules). Specifically, the returned value is a list containing + tuples of (name of the submodule that's stored in the graph module, the + submodule itself, and the fx node that uses this submodule). + """ + return _get_control_flow_submodules( + graph_module, + { + torch.ops.higher_order.cond: [1, 2], + torch.ops.higher_order.while_loop: [0, 1], + }, + ) + + def bfs_trace_with_node_process( gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None] ) -> None: diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 61414990703..c0ff61242df 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -13,7 +13,7 @@ import torch import torch.utils._pytree as pytree -from executorch.exir._serialize import _serialize_pte_binary +from executorch.exir._serialize import _PTEFile, _serialize_pte_binary from executorch.exir._serialize._named_data_store import NamedDataStoreOutput from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name @@ -164,12 +164,14 @@ def buffer( # TODO(T181463742): avoid calling bytes(..) which incurs large copies. out = bytes( _serialize_pte_binary( - program=self.program(memory_planning=memory_planning), + pte_file=_PTEFile( + program=self.program(memory_planning=memory_planning), + named_data=self.named_data_store_output, + ), extract_delegate_segments=extract_delegate_segments, segment_alignment=segment_alignment, constant_tensor_alignment=constant_tensor_alignment, delegate_alignment=delegate_alignment, - named_data=self.named_data_store_output, ) ) return out diff --git a/exir/program/_program.py b/exir/program/_program.py index 9298eb3e88d..03c9aeed886 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -79,7 +79,6 @@ EXIREdgeDialectVerifier, get_aten_verifier, ) -from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass from torch._export.verifier import Verifier from torch.export import ExportedProgram @@ -590,6 +589,10 @@ def __init__( self._segment_alignment: int = segment_alignment self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment self._delegate_alignment: Optional[int] = delegate_alignment + from executorch.extension.flat_tensor.serialize.serialize import ( + FlatTensorSerializer, + ) + self._data_serializer: DataSerializer = FlatTensorSerializer() def _get_emitter_output(self) -> EmitterOutput: @@ -836,13 +839,11 @@ def edge_to_executorch_passes( Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass. """ passes: List[PassType] = [ - SpecPropPass(), # ExecuTorch backend ops are unable to handle unbacked symints. So after # this pass, passes cannot be Interpreter-based, because it will fail if # there exists an unbacked symint operation. *config.passes, - # config.passes may contain external_constants_pass. This pass has to - # run after SpecPropPass, which populates tensor names. + SpecPropPass(), EdgeToBackendOpsPass(), RemoveGraphAssertsPass(), ] + pre_memory_planning_passes(config, name) @@ -1734,11 +1735,19 @@ def to_executorch( # noqa (FLAKE8) C901 # TODO(who?) p.update_placeholder_tensor_specs(program, new_gm) - # Extract constants if the config says too. - if config.external_constants: + # Tag constant weights. + if ( + isinstance(config.external_constants, bool) + and config.external_constants + ): new_gm_res = external_constants_pass(new_gm) new_gm = new_gm_res.graph_module - elif config.external_mutable_weights: + elif callable(config.external_constants): + new_gm_res = external_constants_pass(new_gm, config.external_constants) + new_gm = new_gm_res.graph_module + + # Tag mutable weights. + if config.external_mutable_weights: new_gm_res = external_mutable_weights_pass(new_gm, program) new_gm = new_gm_res.graph_module @@ -1839,6 +1848,10 @@ def __init__( ) # Serialize emitter output, ready to be written to a file. + from executorch.extension.flat_tensor.serialize.serialize import ( + FlatTensorSerializer, + ) + self._data_serializer = FlatTensorSerializer() self._pte_data, self._tensor_data = serialize_for_executorch( self._emitter_output, diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index caab322d27b..885f018a1f8 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -2143,17 +2143,23 @@ def deserialize_meta_func(serialized_target: str): def import_nn_module_stack(key, path, ty): return key, (path, ty) - # Helper function that splits strings by commas except for those - # encapsulated by parens, which are valid traces. - # TODO: Currently this is needed due to indexing Sequential - # layers introducing names in the form "layer.slice(1, None, None)". - # If that naming is improved, this fancier splitting can probably be - # reverted to a simple split by comma. + # Helper function to split string by commas, accounting for nested parentheses/brackets def metadata_split(metadata): - # Remove the parentheses and commas inside them - metadata = re.sub(r"\(.*?\)", "", metadata) - # Split the string by comma, except for those inside parentheses - return re.split(r"(?> AsrRunner::transcribe( Info, "Conversion complete, first value = %f", static_cast( - preprocessed_features - ->mutable_data_ptr<::executorch::aten::BFloat16>()[0])); + preprocessed_features->mutable_data_ptr()[0])); } } @@ -223,9 +222,7 @@ Result> AsrRunner::transcribe( ET_LOG( Info, "Encoder first value: %f", - static_cast( - encoder_output_tensor - .mutable_data_ptr<::executorch::aten::BFloat16>()[0])); + static_cast(encoder_output_tensor.mutable_data_ptr()[0])); auto encoder_output_ptr = std::make_shared<::executorch::aten::Tensor>( std::move(encoder_output_tensor)); diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS index 9a437e7dad5..5dda2318f3f 100644 --- a/extension/llm/custom_ops/TARGETS +++ b/extension/llm/custom_ops/TARGETS @@ -60,5 +60,6 @@ runtime.python_test( ], deps = [ "//caffe2:torch", + "//executorch/extension/pybindings:portable_lib", ], ) diff --git a/extension/llm/custom_ops/test_quantized_sdpa.py b/extension/llm/custom_ops/test_quantized_sdpa.py index 87026d5c251..e6edf6ffbb1 100644 --- a/extension/llm/custom_ops/test_quantized_sdpa.py +++ b/extension/llm/custom_ops/test_quantized_sdpa.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from executorch.extension.llm.custom_ops import custom_ops # noqa +from executorch.extension.pybindings.portable_lib import _unsafe_reset_threadpool def is_fbcode(): @@ -40,6 +41,11 @@ def setUp(self): self.q_shape = None self.kv_shape = None self.is_seq_at_dim_2 = True + # For some reason 4 threads doesnt work + # This setting is needed to make this test not flaky due to OMP + # error of "OMP: Error #131: Thread identifier invalid" + # Not clear why that happens but having smaller threadpool resolves it + _unsafe_reset_threadpool(3) def _scale_tensor(self, tensor, min_value, max_value, scale=True): normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index f8c556f351c..675c0179ebb 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -473,7 +473,11 @@ def to_edge_transform_and_lower( return self def to_executorch( - self, passes: Optional[List[ExportPass]] = None + self, + passes: Optional[List[ExportPass]] = None, + external_constants_tag: Optional[ + Callable[[torch.fx.Node], Optional[str]] + ] = None, ) -> "LLMEdgeManager": """ Lower the model to executorch and get an ExecutorchProgram. @@ -506,6 +510,7 @@ def to_executorch( do_quant_fusion_and_const_prop=True, memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + external_constants=external_constants_tag, ) ) logging.info( diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index 8d280b4eaf9..6a2c1989922 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -55,6 +55,25 @@ target_include_directories( extension_llm_runner INTERFACE ${_common_include_directories} ) +# If the project is configured to build with CUDA support, try to find a CUDA +# runtime (prefer the CUDAToolkit package). If found, expose a compile-time +# macro so sources can conditionally compile CUDA-aware code. +if(EXECUTORCH_BUILD_CUDA) + # Prefer the modern CMake CUDAToolkit module, fall back to searching for the + # CUDA runtime library (cudart) if the package isn't available. + find_package(CUDAToolkit QUIET) + if(CUDAToolkit_FOUND) + target_compile_definitions(extension_llm_runner PUBLIC CUDA_AVAILABLE) + target_link_libraries(extension_llm_runner PUBLIC CUDA::cudart) + message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE") + else() + message( + STATUS + "CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found" + ) + endif() +endif() + install( TARGETS extension_llm_runner EXPORT ExecuTorchTargets diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 674be820072..13f8d7a9db5 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -200,7 +200,8 @@ std::unique_ptr create_text_llm_runner( const std::string& model_path, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::vector data_files, - float temperature) { + float temperature, + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) { // Sanity check tokenizer if (!tokenizer || !tokenizer->is_loaded()) { ET_LOG(Error, "Tokenizer is null or not loaded"); @@ -211,9 +212,13 @@ std::unique_ptr create_text_llm_runner( std::unique_ptr module; if (data_files.size() > 0) { module = std::make_unique( - model_path, data_files, Module::LoadMode::File); + model_path, + data_files, + Module::LoadMode::File, + std::move(event_tracer)); } else { - module = std::make_unique(model_path, Module::LoadMode::File); + module = std::make_unique( + model_path, Module::LoadMode::File, std::move(event_tracer)); } // Get metadata from Module diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h index 08f0efd0353..424567b7c2b 100644 --- a/extension/llm/runner/llm_runner_helper.h +++ b/extension/llm/runner/llm_runner_helper.h @@ -123,7 +123,8 @@ ET_EXPERIMENTAL std::unique_ptr create_text_llm_runner( const std::string& model_path, std::unique_ptr<::tokenizers::Tokenizer> tokenizer, std::vector data_files = {}, - float temperature = -1.0f); + float temperature = -1.0f, + std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr); /** * @brief Creates a MultimodalRunner instance with dependency injection diff --git a/extension/llm/runner/multimodal_runner.cpp b/extension/llm/runner/multimodal_runner.cpp index 047ca27ee2b..5c0c1e658a7 100644 --- a/extension/llm/runner/multimodal_runner.cpp +++ b/extension/llm/runner/multimodal_runner.cpp @@ -15,6 +15,10 @@ #include #include +#ifdef CUDA_AVAILABLE +#include +#endif + namespace executorch::extension::llm { using ::executorch::extension::Module; @@ -38,7 +42,16 @@ MultimodalRunner::MultimodalRunner( io_manager_(std::move(io_manager)), text_token_generator_(std::move(text_token_generator)), stats_(std::move(stats)), - pos_(0) {} + pos_(0) { +#ifdef CUDA_AVAILABLE + cuda_memory_tracker_ = + std::make_unique<::executorch::backends::cuda::CudaMemoryTracker>(); + // Probe immediately after creating the tracker to capture GPU state before + // any model loading happens. + stats_->gpu_total_bytes = cuda_memory_tracker_->total_bytes(); + stats_->gpu_free_before_load_bytes = cuda_memory_tracker_->last_free_bytes(); +#endif +} bool MultimodalRunner::is_loaded() { return multimodal_prefiller_->is_method_loaded() && @@ -49,8 +62,18 @@ Error MultimodalRunner::load() { if (is_loaded()) { return Error::Ok; } + stats_->model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(multimodal_prefiller_->load()); ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); + stats_->model_load_end_ms = time_in_ms(); + +#ifdef CUDA_AVAILABLE + cuda_memory_tracker_->log_sample("after_load"); + stats_->gpu_total_bytes = cuda_memory_tracker_->total_bytes(); + stats_->gpu_free_after_load_bytes = cuda_memory_tracker_->last_free_bytes(); + stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb(); +#endif + return Error::Ok; } @@ -86,9 +109,7 @@ Error MultimodalRunner::generate( } if (!is_loaded()) { - stats_->model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); - stats_->model_load_end_ms = time_in_ms(); } if (config.warming) { @@ -192,6 +213,15 @@ Error MultimodalRunner::generate( stats_->num_generated_tokens = num_generated_tokens; // Finalize stats and call callback stats_->inference_end_ms = time_in_ms(); + +#ifdef CUDA_AVAILABLE + cuda_memory_tracker_->log_sample("after_generate"); + stats_->gpu_free_after_generate_bytes = + cuda_memory_tracker_->last_free_bytes(); + // update peak in case it changed after generation + stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb(); +#endif + if (!config.warming) { printf("\n"); } diff --git a/extension/llm/runner/multimodal_runner.h b/extension/llm/runner/multimodal_runner.h index caf3c296038..b34b7b05ce7 100644 --- a/extension/llm/runner/multimodal_runner.h +++ b/extension/llm/runner/multimodal_runner.h @@ -36,6 +36,10 @@ // These are provided for backward compatibility #include +#ifdef CUDA_AVAILABLE +#include +#endif + namespace executorch { namespace extension { namespace llm { @@ -150,6 +154,11 @@ class ET_EXPERIMENTAL MultimodalRunner { std::unique_ptr text_token_generator_; std::unique_ptr stats_; +#ifdef CUDA_AVAILABLE + std::unique_ptr<::executorch::backends::cuda::CudaMemoryTracker> + cuda_memory_tracker_; +#endif + // Internal state int64_t pos_; }; diff --git a/extension/llm/runner/stats.h b/extension/llm/runner/stats.h index 19766329ed3..d52e02ccc24 100644 --- a/extension/llm/runner/stats.h +++ b/extension/llm/runner/stats.h @@ -44,11 +44,19 @@ struct ET_EXPERIMENTAL Stats { // inference_end_ms: End of inference/generation. long inference_end_ms; // Keep a running total of the time spent in sampling. - long aggregate_sampling_time_ms; + long aggregate_sampling_time_ms = 0; // Token count from prompt int64_t num_prompt_tokens; // Token count from generated (total - prompt) int64_t num_generated_tokens; + // GPU memory stats (optional; may be zero if not available) + // GPU memory stats (optional). Use sentinel UINT64_MAX / -1.0 to indicate + // "not available". + uint64_t gpu_total_bytes = static_cast(-1); + uint64_t gpu_free_before_load_bytes = static_cast(-1); + uint64_t gpu_free_after_load_bytes = static_cast(-1); + uint64_t gpu_free_after_generate_bytes = static_cast(-1); + double gpu_peak_usage_mb = -1.0; inline void on_sampling_begin() { aggregate_sampling_timer_start_timestamp = time_in_ms(); } @@ -75,6 +83,11 @@ struct ET_EXPERIMENTAL Stats { aggregate_sampling_time_ms = 0; num_prompt_tokens = 0; num_generated_tokens = 0; + gpu_total_bytes = static_cast(-1); + gpu_free_before_load_bytes = static_cast(-1); + gpu_free_after_load_bytes = static_cast(-1); + gpu_free_after_generate_bytes = static_cast(-1); + gpu_peak_usage_mb = -1.0; aggregate_sampling_timer_start_timestamp = 0; } @@ -93,7 +106,29 @@ inline std::string stats_to_json_string(const Stats& stats) { << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," << "\"first_token_ms\":" << stats.first_token_ms << "," << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms - << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" + << ","; + // Only include GPU fields in the JSON if gpu_total_bytes is valid (not + // equal to sentinel -1) + if (stats.gpu_total_bytes != static_cast(-1)) { + ss << "\"gpu_total_bytes\":" << stats.gpu_total_bytes; + if (stats.gpu_free_before_load_bytes != static_cast(-1)) { + ss << ",\"gpu_free_before_load_bytes\":" + << stats.gpu_free_before_load_bytes; + } + if (stats.gpu_free_after_load_bytes != static_cast(-1)) { + ss << ",\"gpu_free_after_load_bytes\":" + << stats.gpu_free_after_load_bytes; + } + if (stats.gpu_free_after_generate_bytes != static_cast(-1)) { + ss << ",\"gpu_free_after_generate_bytes\":" + << stats.gpu_free_after_generate_bytes; + } + if (stats.gpu_peak_usage_mb >= 0.0) { + ss << ",\"gpu_peak_usage_mb\":" << stats.gpu_peak_usage_mb; + } + ss << ","; + } + ss << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; return ss.str(); } @@ -156,6 +191,35 @@ inline void print_report(const Stats& stats) { stats.num_prompt_tokens + stats.num_generated_tokens, (double)stats.aggregate_sampling_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND); + + // GPU memory reporting (only meaningful if GPU fields were populated) + if (stats.gpu_total_bytes != static_cast(-1)) { + ET_LOG( + Info, + "\tGPU total memory: %.2f MB", + stats.gpu_total_bytes / 1024.0 / 1024.0); + if (stats.gpu_free_before_load_bytes != static_cast(-1)) { + ET_LOG( + Info, + "\tGPU free before load: %.2f MB", + stats.gpu_free_before_load_bytes / 1024.0 / 1024.0); + } + if (stats.gpu_free_after_load_bytes != static_cast(-1)) { + ET_LOG( + Info, + "\tGPU free after load: %.2f MB", + stats.gpu_free_after_load_bytes / 1024.0 / 1024.0); + } + if (stats.gpu_free_after_generate_bytes != static_cast(-1)) { + ET_LOG( + Info, + "\tGPU free after generate: %.2f MB", + stats.gpu_free_after_generate_bytes / 1024.0 / 1024.0); + } + if (stats.gpu_peak_usage_mb >= 0.0) { + ET_LOG(Info, "\tGPU peak usage: %.2f MB", stats.gpu_peak_usage_mb); + } + } } } // namespace llm diff --git a/extension/pybindings/portable_lib.py b/extension/pybindings/portable_lib.py index 0982d55b474..27468c8b7b5 100644 --- a/extension/pybindings/portable_lib.py +++ b/extension/pybindings/portable_lib.py @@ -65,6 +65,7 @@ _load_program, # noqa: F401 _load_program_from_buffer, # noqa: F401 _reset_profile_results, # noqa: F401 + _threadpool_get_thread_count, # noqa: F401 _unsafe_reset_threadpool, # noqa: F401 BundledModule, # noqa: F401 ExecuTorchMethod, # noqa: F401 diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index c3cd4ed0b47..eb81bda22f7 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -1558,6 +1558,13 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) { }, py::arg("num_threads"), call_guard); + m.def( + "_threadpool_get_thread_count", + []() { + return ::executorch::extension::threadpool::get_threadpool() + ->get_thread_count(); + }, + call_guard); py::class_(m, "ExecuTorchModule") .def( diff --git a/extension/pybindings/pybindings.pyi b/extension/pybindings/pybindings.pyi index a3b75780369..9e5ab6211ce 100644 --- a/extension/pybindings/pybindings.pyi +++ b/extension/pybindings/pybindings.pyi @@ -288,3 +288,12 @@ def _unsafe_reset_threadpool(num_threads: int) -> None: This API is experimental and subject to change without notice. """ ... + +@experimental("This API is experimental and subject to change without notice.") +def _threadpool_get_thread_count() -> int: + """ + .. warning:: + + This API is experimental and subject to change without notice. + """ + ... diff --git a/install_executorch.bat b/install_executorch.bat index e6d5c5db363..50fb6fd9b77 100644 --- a/install_executorch.bat +++ b/install_executorch.bat @@ -1,4 +1,5 @@ @ECHO OFF +setlocal EnableDelayedExpansion rem Copyright (c) Meta Platforms, Inc. and affiliates. rem All rights reserved. @@ -7,9 +8,24 @@ rem This batch file provides a basic functionality similar to the bash script. cd /d "%~dp0" +rem Verify that Git checked out symlinks correctly. Without this the Python install +rem will fail when attempting to copy files from src\executorch. +where git >NUL 2>&1 +if not errorlevel 1 ( + set "GIT_SYMLINKS=" + for /f "usebackq delims=" %%i in (`git config --get core.symlinks 2^>nul`) do set "GIT_SYMLINKS=%%i" + if /I not "!GIT_SYMLINKS!"=="true" ( + echo ExecuTorch requires Git symlink support on Windows. + echo Enable Developer Mode and run: git config --global core.symlinks true + echo Re-clone the repository after enabling symlinks, then rerun install_executorch.bat. + exit /b 1 + ) +) + rem Under windows, it's always python set PYTHON_EXECUTABLE=python "%PYTHON_EXECUTABLE%" install_executorch.py %* -exit /b %ERRORLEVEL% +set "EXIT_CODE=%ERRORLEVEL%" +endlocal & exit /b %EXIT_CODE% diff --git a/install_requirements.py b/install_requirements.py index b84e250cf87..df26badbef3 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -23,6 +23,7 @@ SUPPORTED_CUDA_VERSIONS = ( (12, 6), (12, 8), + (12, 9), (13, 0), ) @@ -130,7 +131,11 @@ def install_optional_example_requirements(use_pytorch_nightly): if use_pytorch_nightly else "torchvision" ), - f"torchaudio==2.8.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torchaudio", + ( + f"torchaudio==2.10.0.{NIGHTLY_VERSION}" + if use_pytorch_nightly + else "torchaudio" + ), ] # Then install domain libraries subprocess.run( diff --git a/kernels/portable/cpu/op__clone_dim_order.cpp b/kernels/portable/cpu/op__clone_dim_order.cpp index 83045768cf2..9a4f68bdc46 100644 --- a/kernels/portable/cpu/op__clone_dim_order.cpp +++ b/kernels/portable/cpu/op__clone_dim_order.cpp @@ -10,6 +10,9 @@ #include #include +#include +#include + namespace torch { namespace executor { namespace native { @@ -19,6 +22,30 @@ using Tensor = executorch::aten::Tensor; template using OptionalArrayRef = executorch::aten::OptionalArrayRef; +namespace { + +/** + * Checks the conditions for fast path direct memcpy. This can be used + * when the output dim order is unchanged. + */ +bool check_fast_path_conditions( + const Tensor& in, + OptionalArrayRef dim_order) { + if (!dim_order.has_value()) { + // No dim order means preserve input dim order. + return true; + } + + auto input_dim_order = in.dim_order(); + return std::equal( + dim_order.value().begin(), + dim_order.value().end(), + input_dim_order.begin(), + input_dim_order.end()); +} + +} // namespace + /** * _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? * dim_order=None, Tensor(a!) out) -> Tensor(a!) @@ -55,13 +82,18 @@ Tensor& _clone_dim_order_out( return out; } - // Select the correct input dtype and copy the tensors. - ET_SWITCH_REALHBBF16_TYPES( - self.scalar_type(), - ctx, - "dim_order_ops::_clone_dim_order.out", - CTYPE, - [&] { _to_dim_order_copy_impl(self, out); }); + // Dispatch to the fast path if we can use direct memcpy. + if (check_fast_path_conditions(self, dim_order)) { + std::memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes()); + } else { + // Select the correct input dtype and copy the tensors. + ET_SWITCH_REALHBBF16_TYPES( + self.scalar_type(), + ctx, + "dim_order_ops::_clone_dim_order.out", + CTYPE, + [&] { _to_dim_order_copy_impl(self, out); }); + } return out; } @@ -77,4 +109,4 @@ Tensor& _clone_dim_order_out( } // namespace native } // namespace executor -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/kernels/portable/cpu/op_constant_pad_nd.cpp b/kernels/portable/cpu/op_constant_pad_nd.cpp index 7da10456e58..7209e8e42e5 100644 --- a/kernels/portable/cpu/op_constant_pad_nd.cpp +++ b/kernels/portable/cpu/op_constant_pad_nd.cpp @@ -30,6 +30,7 @@ void set_all_to_value(CTYPE* out_data, size_t step_len, CTYPE value) { template void apply_padding_to_dim( + KernelRuntimeContext& ctx, size_t ndim, const CTYPE* self_data, IntArrayRef self_sizes, @@ -57,7 +58,20 @@ void apply_padding_to_dim( size_t out_step_len = out_strides[dim]; size_t in_step_len = self_strides[dim]; - for ([[maybe_unused]] const auto i : c10::irange(pad_before)) { + // Do not copy padding beyond the out tensor bounds. + if (pad_before > 0) { + size_t numel = 1; + for (ET_UNUSED const auto i : c10::irange(out_sizes.size())) { + numel *= out_sizes[i]; + } + ET_KERNEL_CHECK_MSG( + ctx, + numel >= pad_before * out_step_len, + InvalidArgument, + /* void */, + "Out tensor is too small for the requested padding."); + } + for (ET_UNUSED const auto i : c10::irange(pad_before)) { set_all_to_value(out_data, out_step_len, value); out_data += out_step_len; } @@ -76,8 +90,9 @@ void apply_padding_to_dim( } // Otherwise, call this function recursively else { - for ([[maybe_unused]] const auto i : c10::irange(self_sizes[dim])) { + for (ET_UNUSED const auto i : c10::irange(self_sizes[dim])) { apply_padding_to_dim( + ctx, ndim, self_data, self_sizes, @@ -95,7 +110,20 @@ void apply_padding_to_dim( } } - for ([[maybe_unused]] const auto i : c10::irange(pad_after)) { + // Do not copy padding beyond the out tensor bounds. + if (pad_after > 0) { + size_t numel = 1; + for (ET_UNUSED const auto i : c10::irange(out_sizes.size())) { + numel *= out_sizes[i]; + } + ET_KERNEL_CHECK_MSG( + ctx, + numel >= pad_after * out_step_len, + InvalidArgument, + /* void */, + "Out tensor is too small for the requested padding."); + } + for (ET_UNUSED const auto i : c10::irange(pad_after)) { set_all_to_value(out_data, out_step_len, value); out_data += out_step_len; } @@ -103,6 +131,7 @@ void apply_padding_to_dim( template void constant_pad_nd_out_impl( + KernelRuntimeContext& ctx, const Tensor& self, IntArrayRef pad, CTYPE value_v, @@ -145,6 +174,7 @@ void constant_pad_nd_out_impl( IntArrayRef out_strides_ref(out_strides, ndim); apply_padding_to_dim( + ctx, ndim, self_data, self_sizes_ref, @@ -192,7 +222,7 @@ Tensor& constant_pad_nd_out( utils::internal::check_overflow_scalar_cast(value); ET_KERNEL_CHECK(ctx, opt_value_casted.has_value(), InvalidArgument, ); auto value_casted = opt_value_casted.value(); - constant_pad_nd_out_impl(in, pad, value_casted, out); + constant_pad_nd_out_impl(ctx, in, pad, value_casted, out); }); return out; diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 968231fc42e..8164d1ebb02 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -49,7 +49,8 @@ Tensor& copy_out( // Use direct copy fast path if broadcast is not needed and tensors are // non-empty if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) && - src.numel() > 0) { + src.numel() > 0 && out.nbytes() >= src.nbytes() && + tensors_have_same_dtype(src, out)) { std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes()); } else { ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() { @@ -91,8 +92,9 @@ Tensor& copy_( // Use direct copy fast path if broadcast is not needed and tensors are // non-empty if (internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) && - src.numel() > 0) { - std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes()); + src.numel() > 0 && in.nbytes() >= src.nbytes() && + tensors_have_same_dtype(src, in)) { + std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), src.nbytes()); } else { ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() { utils::apply_bitensor_elementwise_fn< diff --git a/kernels/portable/cpu/op_index.cpp b/kernels/portable/cpu/op_index.cpp index 6ce9fb375de..d8eb992b85a 100644 --- a/kernels/portable/cpu/op_index.cpp +++ b/kernels/portable/cpu/op_index.cpp @@ -49,23 +49,6 @@ bool check_fast_path_conditions( if (index.dim() != 1) { return false; } - - // Fast path only supports non-negative indices. - if (ix_type == ScalarType::Int) { - const int32_t* const data = index.const_data_ptr(); - if (std::any_of(data, data + index.numel(), [](const auto x) { - return x < 0; - })) { - return false; - } - } else { // ScalarType::Long - const int64_t* const data = index.const_data_ptr(); - if (std::any_of(data, data + index.numel(), [](const auto x) { - return x < 0; - })) { - return false; - } - } } } @@ -96,8 +79,10 @@ bool check_fast_path_args( Long, Int, index.scalar_type(), ctx, "index.Tensor", CTYPE, [&]() { const CTYPE* const index_arr = index.const_data_ptr(); for (const auto i : c10::irange(index.numel())) { - if (index_arr[i] < 0 || - index_arr[i] >= static_cast(in.size(dim))) { + CTYPE index_val = index_arr[i]; + CTYPE dim_size = static_cast(in.size(dim)); + index_val = index_val < 0 ? index_val + dim_size : index_val; + if (index_val < 0 || index_val >= dim_size) { ET_LOG( Error, "Index %" PRId64 @@ -189,11 +174,14 @@ Tensor& fast_path( ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() { const CTYPE* const index_arr = index.const_data_ptr(); + CTYPE dim_size = static_cast(in.size(dim)); for (const auto i : c10::irange(leading_dims)) { const char* src = in_data + i * in_dim_length * length_per_step; char* dest = out_data + i * out_dim_length * length_per_step; for (const auto j : c10::irange(out_dim_length)) { - const char* copy_src = src + index_arr[j] * length_per_step; + auto index_val = + index_arr[j] < 0 ? index_arr[j] + dim_size : index_arr[j]; + const char* copy_src = src + index_val * length_per_step; char* copy_dest = dest + j * length_per_step; memcpy(copy_dest, copy_src, length_per_step); } diff --git a/kernels/portable/cpu/op_narrow_copy.cpp b/kernels/portable/cpu/op_narrow_copy.cpp index 960ea35efac..260ee9697db 100644 --- a/kernels/portable/cpu/op_narrow_copy.cpp +++ b/kernels/portable/cpu/op_narrow_copy.cpp @@ -46,7 +46,7 @@ Tensor& narrow_copy_out( out); if (length != 0) { - compute_slice(in, dim, start, length, 1, out); + compute_slice(ctx, in, dim, start, length, 1, out); } return out; diff --git a/kernels/portable/cpu/op_slice_copy.cpp b/kernels/portable/cpu/op_slice_copy.cpp index 1d4e509e083..0baacb874e1 100644 --- a/kernels/portable/cpu/op_slice_copy.cpp +++ b/kernels/portable/cpu/op_slice_copy.cpp @@ -55,7 +55,7 @@ Tensor& slice_copy_Tensor_out( InvalidArgument, out); - compute_slice(in, dim, start, length, step, out); + compute_slice(ctx, in, dim, start, length, step, out); return out; } diff --git a/kernels/portable/cpu/op_view_as_real_copy.cpp b/kernels/portable/cpu/op_view_as_real_copy.cpp index 4461ecb02f8..fe0ced2f722 100644 --- a/kernels/portable/cpu/op_view_as_real_copy.cpp +++ b/kernels/portable/cpu/op_view_as_real_copy.cpp @@ -41,6 +41,14 @@ Tensor& view_as_real_copy_out( // Get the output shape Tensor::SizesType expected_output_size[kTensorDimensionLimit]; + ET_KERNEL_CHECK_MSG( + ctx, + static_cast(self.dim()) < kTensorDimensionLimit, + InvalidArgument, + out, + "Output size buffer is too small. Expected at least %zu, got %zu", + self.dim() + 1, + kTensorDimensionLimit); get_view_as_real_copy_out_target_size(self, expected_output_size); // Resize for dynamic shape diff --git a/kernels/portable/cpu/util/padding_util.h b/kernels/portable/cpu/util/padding_util.h index 50cfcc65643..234c5ba5602 100644 --- a/kernels/portable/cpu/util/padding_util.h +++ b/kernels/portable/cpu/util/padding_util.h @@ -56,8 +56,9 @@ void pad1d( size_t out_i_base = i * out_width; size_t in_i_base = i * in_width; for (const auto w : c10::irange(out_width)) { - out_data[out_i_base + w] = - in_data[in_i_base + padding_ix(w, in_width, pad_left)]; + int64_t in_w_idx = padding_ix(w, in_width, pad_left); + ET_CHECK(in_w_idx >= 0 && in_w_idx < in_width); + out_data[out_i_base + w] = in_data[in_i_base + in_w_idx]; } } } @@ -85,11 +86,13 @@ void pad2d( size_t in_i_base = i * in_height * in_width; for (const auto h : c10::irange(out_height)) { size_t out_h_base = out_i_base + h * out_width; - size_t in_h_base = - in_i_base + padding_ix(h, in_height, pad_top) * in_width; + int64_t in_h_idx = padding_ix(h, in_height, pad_top); + ET_CHECK(in_h_idx >= 0 && in_h_idx < in_height); + size_t in_h_base = in_i_base + in_h_idx * in_width; for (const auto w : c10::irange(out_width)) { - out_data[out_h_base + w] = - in_data[in_h_base + padding_ix(w, in_width, pad_left)]; + int64_t in_w_idx = padding_ix(w, in_width, pad_left); + ET_CHECK(in_w_idx >= 0 && in_w_idx < in_width); + out_data[out_h_base + w] = in_data[in_h_base + in_w_idx]; } } } @@ -121,15 +124,18 @@ void pad3d( size_t in_i_base = i * in_depth * in_height * in_width; for (const auto d : c10::irange(out_depth)) { size_t out_d_base = out_i_base + d * out_height * out_width; - size_t in_d_base = - in_i_base + padding_ix(d, in_depth, pad_front) * in_height * in_width; + int64_t in_d_base_padding = padding_ix(d, in_depth, pad_front); + ET_CHECK(in_d_base_padding >= 0 && in_d_base_padding < in_depth); + size_t in_d_base = in_i_base + in_d_base_padding * in_height * in_width; for (const auto h : c10::irange(out_height)) { size_t out_h_base = out_d_base + h * out_width; - size_t in_h_base = - in_d_base + padding_ix(h, in_height, pad_top) * in_width; + int64_t in_h_base_padding = padding_ix(h, in_height, pad_top); + ET_CHECK(in_h_base_padding >= 0 && in_h_base_padding < in_height); + size_t in_h_base = in_d_base + in_h_base_padding * in_width; for (const auto w : c10::irange(out_width)) { - out_data[out_h_base + w] = - in_data[in_h_base + padding_ix(w, in_width, pad_left)]; + int64_t in_w_base_padding = padding_ix(w, in_width, pad_left); + ET_CHECK(in_w_base_padding >= 0 && in_w_base_padding < in_width); + out_data[out_h_base + w] = in_data[in_h_base + in_w_base_padding]; } } } diff --git a/kernels/portable/cpu/util/slice_util.cpp b/kernels/portable/cpu/util/slice_util.cpp index 909bd827d79..05e2f7d8289 100644 --- a/kernels/portable/cpu/util/slice_util.cpp +++ b/kernels/portable/cpu/util/slice_util.cpp @@ -81,7 +81,7 @@ bool check_slice_scatter_args( Tensor output) { ET_LOG_AND_RETURN_IF_FALSE(input.dim() > 0); - // Check dim. The dim planed to be selected on shall exist in input + // Check dim. The dim planned to be selected on shall exist in input ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, input.dim())); // Input and output tensors should be the same shape and dtype @@ -97,7 +97,7 @@ bool check_slice_scatter_args( // The size of src tensor should follow these rules: // - src.size(i) shall equal to input.size(i) if i != dim, // - src.size(dim) shall equal to num_values - for (const auto d : c10::irange(input.dim() - 1)) { + for (const auto d : c10::irange(input.dim())) { if (d != dim) { ET_LOG_AND_RETURN_IF_FALSE( tensors_have_same_size_at_dims(input, d, src, d)); @@ -150,13 +150,39 @@ int64_t adjust_slice_indices( } void compute_slice( + KernelRuntimeContext& ctx, const Tensor& in, int64_t dim, int64_t start, int64_t length, int64_t step, Tensor& out) { + // No slicing requested. + if (length <= 0) { + return; + } + + ET_KERNEL_CHECK_MSG( + ctx, + dim < in.dim(), + InvalidArgument, + /* void */, + "Requested dim is larger than input tensor dim"); size_t dim_length = in.size(dim); + ET_KERNEL_CHECK_MSG( + ctx, + start >= 0 && length >= 0 && step >= 0, + InvalidArgument, + /* void */, + "Input args should be >= 0."); + int64_t requested_slice = start + (length - 1) * step; + ET_KERNEL_CHECK_MSG( + ctx, + static_cast(requested_slice) < + static_cast(dim_length), + InvalidArgument, + /* void */, + "Requested slice is larger than the dim size"); size_t leading_dims = getLeadingDims(in, dim); size_t trailing_dims = getTrailingDims(in, dim); @@ -170,6 +196,12 @@ void compute_slice( const char* input_data = in.const_data_ptr(); char* dest = out.mutable_data_ptr(); + ET_KERNEL_CHECK_MSG( + ctx, + out.nbytes() >= (length * leading_dims * length_per_step), + InvalidArgument, + /* void */, + "out.nbytes() is smaller than the expected slice size."); for (const auto i : c10::irange(leading_dims)) { const char* src = input_data + (i * dim_length + start) * length_per_step; for ([[maybe_unused]] const auto j : c10::irange(length)) { diff --git a/kernels/portable/cpu/util/slice_util.h b/kernels/portable/cpu/util/slice_util.h index accfb387246..52a0b05c864 100644 --- a/kernels/portable/cpu/util/slice_util.h +++ b/kernels/portable/cpu/util/slice_util.h @@ -55,6 +55,7 @@ int64_t adjust_slice_indices( int64_t step); void compute_slice( + KernelRuntimeContext& ctx, const Tensor& in, int64_t dim, int64_t start, diff --git a/kernels/prim_ops/et_copy_index.cpp b/kernels/prim_ops/et_copy_index.cpp index dfcaf1eb550..2ef076ad1a0 100644 --- a/kernels/prim_ops/et_copy_index.cpp +++ b/kernels/prim_ops/et_copy_index.cpp @@ -59,7 +59,7 @@ constexpr size_t kTensorDimensionLimit = 16; // torch.ops.executorch.prim.add.int(iteration_index, 1, iteration_index) // done_bool = torch.ops.executorch.prim.eq.int(iteration_index, // sym_size, done_bool) # Emitter inserts a instruction here, if -// done_bool == False jump to selcect_copy op # if not continue. return +// done_bool == False jump to select_copy op # if not continue. return // add_tensor // // The output of each iteration (copy_from) is copied into the copy_to tensor at @@ -79,12 +79,24 @@ void et_copy_index(KernelRuntimeContext& context, Span stack) { auto copy_from = (*stack[1]).toTensor(); auto index = (*stack[2]).toInt(); + ET_KERNEL_CHECK_MSG( + context, + index >= 0, + InvalidArgument, + /* void */, + "Expected index to be non-negative."); + // Number of bytes we need to copy over from copy_from tensor. size_t size_copy_from = (copy_from.element_size()) * (copy_from.numel()); - ET_CHECK_MSG( + ET_KERNEL_CHECK_MSG( + context, (copy_to.sizes().size() - copy_from.sizes().size()) == 1, - "Ranks of copy_to and copy_from tensor should only differ by 1."); + InvalidArgument, + /* void */, + "Ranks of copy_to %zu and copy_from tensor %zu should only differ by 1.", + copy_to.sizes().size(), + copy_from.sizes().size()); // Here we calculate the size of the out_tensor after copy_from has // been copied to it. This will be passed onto the resize call. @@ -93,8 +105,11 @@ void et_copy_index(KernelRuntimeContext& context, Span stack) { // If we're copying past the first index then the shape of // copy_from and copy_to without the leading dimension should be // the same. i.e. copy_to.size[1:] == copy_from.size[:]. - ET_CHECK_MSG( + ET_KERNEL_CHECK_MSG( + context, copy_to.sizes()[i + 1] == copy_from.sizes()[i], + InvalidArgument, + /* void */, "Mismatch in shape between copy_to and copy_from tensors"); expected_output_size[i + 1] = copy_from.sizes()[i]; } @@ -105,11 +120,22 @@ void et_copy_index(KernelRuntimeContext& context, Span stack) { Error err = resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()}); ET_CHECK(err == Error::Ok); - ET_CHECK_MSG( + ET_KERNEL_CHECK_MSG( + context, data_ptr == copy_to.const_data_ptr(), + InvalidState, + /* void */, "Data ptr of copy_to tensor changed after resize which isn't allowed for static/upper-bounded tensors"); } + // After potential resize, verify that index is within bounds. + ET_KERNEL_CHECK_MSG( + context, + index < copy_to.sizes()[0], + InvalidArgument, + /* void */, + "Index out of bounds"); + auto copy_to_ptr = copy_to.const_data_ptr(); auto copy_from_ptr = copy_from.const_data_ptr(); @@ -118,12 +144,22 @@ void et_copy_index(KernelRuntimeContext& context, Span stack) { // copy_from into the copy_to tensor. // Check that the destination has enough space for the copy. + ET_KERNEL_CHECK_MSG( + context, + size_copy_from == 0 || + static_cast(index) <= SIZE_MAX / size_copy_from, + InvalidArgument, + /* void */, + "Offset multiplication ."); size_t offset = index * size_copy_from; size_t copy_to_size = copy_to.element_size() * copy_to.numel(); - ET_CHECK_MSG( - offset + size_copy_from <= copy_to_size, - "Buffer overflow: copy_to tensor is smaller than copy_from tensor."); - + ET_KERNEL_CHECK_MSG( + context, + (offset <= SIZE_MAX - size_copy_from) && + (offset + size_copy_from <= copy_to_size), + InvalidArgument, + /* void */, + "Buffer overflow; offset overflow or copy_to tensor is smaller than copy_from tensor."); memcpy( // NOLINTNEXTLINE(performance-no-int-to-ptr) (void*)((uintptr_t)copy_to_ptr + offset), diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp index dc6ed9ac26f..7ff2f1f868d 100644 --- a/kernels/prim_ops/register_prim_ops.cpp +++ b/kernels/prim_ops/register_prim_ops.cpp @@ -141,6 +141,12 @@ static Kernel prim_ops[] = { EValue& out = *stack[1]; executorch::aten::Tensor self_tensor = self.to(); + ET_KERNEL_CHECK_MSG( + context, + self_tensor.numel() >= 1, + InvalidArgument, + /* void */, + "Expected tensor with at least 1 element"); ET_SWITCH_REAL_TYPES_AND( Bool, self_tensor.scalar_type(), diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index 1ccb2c27ce5..b46733045fb 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -276,8 +276,9 @@ TEST_F(RegisterPrimOpsTest, TestETCopyIndexMismatchShape) { // Try to copy and replace at index 1. This will fail because // copy_to.sizes[1:] and to_copy.sizes[:] don't match each other // which is a pre-requisite for this operator. - ET_EXPECT_DEATH( - getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack), ""); + ET_EXPECT_KERNEL_FAILURE( + context_, + getOpsFn("executorch_prim::et_copy_index.tensor")(context_, stack)); } TEST_F(RegisterPrimOpsTest, TestETCopyIndexStaticShape) { diff --git a/kernels/quantized/cpu/op_choose_qparams.cpp b/kernels/quantized/cpu/op_choose_qparams.cpp index 5335f4bfbd2..acb8e100af6 100644 --- a/kernels/quantized/cpu/op_choose_qparams.cpp +++ b/kernels/quantized/cpu/op_choose_qparams.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -202,17 +203,42 @@ void choose_qparams_per_token( num_tokens *= input.size(i); } auto token_dim_size = input.size(input.dim() - 1); - for (auto i = 0; i < num_tokens; i++) { - // vec_minf uses std::min_element. Check if it actually - // gets vectorized. - float min = torch::executor::vec_minf(x_fp32, token_dim_size); - float max = torch::executor::vec_maxf(x_fp32, token_dim_size); - double scale; - int32_t zero_point; - calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point); - scale_out.mutable_data_ptr()[i] = scale; - zero_point_out.mutable_data_ptr()[i] = zero_point; - x_fp32 += token_dim_size; + + const int64_t total_elements = num_tokens * token_dim_size; + constexpr int64_t MIN_ELEMENTS_FOR_PARALLEL = 512; + const bool use_parallel = total_elements >= MIN_ELEMENTS_FOR_PARALLEL; + + if (use_parallel) { + auto* scale_data = scale_out.mutable_data_ptr(); + auto* zero_point_data = zero_point_out.mutable_data_ptr(); + + ::executorch::extension::parallel_for( + 0, num_tokens, 1, [&](const int64_t begin, const int64_t end) { + for (int64_t i = begin; i < end; i++) { + const float* token_data = x_fp32 + i * token_dim_size; + float min = torch::executor::vec_minf(token_data, token_dim_size); + float max = torch::executor::vec_maxf(token_data, token_dim_size); + double scale; + int32_t zero_point; + calculate_scale_and_zero_point( + min, max, qmin, qmax, scale, zero_point); + scale_data[i] = scale; + zero_point_data[i] = zero_point; + } + }); + } else { + for (auto i = 0; i < num_tokens; i++) { + // vec_minf uses std::min_element. Check if it actually + // gets vectorized. + float min = torch::executor::vec_minf(x_fp32, token_dim_size); + float max = torch::executor::vec_maxf(x_fp32, token_dim_size); + double scale; + int32_t zero_point; + calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point); + scale_out.mutable_data_ptr()[i] = scale; + zero_point_out.mutable_data_ptr()[i] = zero_point; + x_fp32 += token_dim_size; + } } } } // namespace diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 5586f8a77eb..e52b9a371e6 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -7,10 +7,15 @@ */ #include +#include #include #include #include +#if defined(__aarch64__) || defined(__ARM_NEON__) +#include +#endif + /** * For an input tensor, use the scale and zero_point arguments to quantize it. */ @@ -105,6 +110,143 @@ T quantize_val( return static_cast(qvalue); } +#if defined(__aarch64__) || defined(__ARM_NEON__) + +// Traits for type-specific NEON operations +template +struct NeonQuantizeTraits; + +template <> +struct NeonQuantizeTraits { + // Narrow int16x8 to uint8x8 with saturation (unsigned) + static inline uint8x8_t narrow_and_saturate(int16x8_t v) { + return vqmovun_s16(v); + } + + // Store uint8x8 to memory + static inline void store(uint8_t* ptr, uint8x8_t v) { + vst1_u8(ptr, v); + } + + // Scalar clamping for uint8 + static inline uint8_t clamp_scalar(int32_t val) { + return static_cast(std::min(255, std::max(0, val))); + } +}; + +template <> +struct NeonQuantizeTraits { + // Narrow int16x8 to int8x8 with saturation (signed) + static inline int8x8_t narrow_and_saturate(int16x8_t v) { + return vqmovn_s16(v); + } + + // Store int8x8 to memory + static inline void store(int8_t* ptr, int8x8_t v) { + vst1_s8(ptr, v); + } + + // Scalar clamping for int8 + static inline int8_t clamp_scalar(int32_t val) { + return static_cast(std::min(127, std::max(-128, val))); + } +}; + +// Unified ARM NEON optimized quantization for contiguous blocks +// Processes N elements with a single scale/zero_point pair +// Used for both per-tensor (entire tensor) and per-channel (one block per +// channel) +template +void quantize_arm( + const float* __restrict__ in, + T* __restrict__ out, + const int64_t N, + const float inv_scale, + const int32_t zero_point, + const int32_t quant_min, + const int32_t quant_max) { + using Traits = NeonQuantizeTraits; + const float32x4_t vinv_scale = vdupq_n_f32(inv_scale); + +#if defined(__aarch64__) + // ARMv8: Use vcvtnq_s32_f32 for rounding + const int16x8_t vzero_point = vdupq_n_s16(static_cast(zero_point)); + const int16x8_t vquant_min = vdupq_n_s16(static_cast(quant_min)); + const int16x8_t vquant_max = vdupq_n_s16(static_cast(quant_max)); + + int64_t i = 0; + // Process 8 elements at a time + for (; i + 8 <= N; i += 8) { + const float32x4_t vin0123 = vld1q_f32(in + i); + const float32x4_t vin4567 = vld1q_f32(in + i + 4); + + // Multiply by inv_scale and round + const int32x4_t v0123_rounded = + vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale)); + const int32x4_t v4567_rounded = + vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale)); + + // Combine to int16 and add zero_point + int16x8_t v01234567_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point); + + // Clamp to quant_min/quant_max + v01234567_packed = vmaxq_s16(v01234567_packed, vquant_min); + v01234567_packed = vminq_s16(v01234567_packed, vquant_max); + + // Convert to T (int8/uint8) with saturation using type-specific operation + const auto vout01234567 = Traits::narrow_and_saturate(v01234567_packed); + Traits::store(out + i, vout01234567); + } + + // Handle remaining elements with proper quant_min/quant_max clamping + for (; i < N; ++i) { + float val = in[i] * inv_scale; + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::max(quant_min, std::min(quant_max, qval)); + out[i] = static_cast(qval); + } + +#else + // ARMv7: Use magic float rounding + const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000); + const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f); + + int64_t i = 0; + // Process 8 elements at a time + for (; i + 8 <= N; i += 8) { + const float32x4_t vin0123 = vld1q_f32(in + i); + const float32x4_t vin4567 = vld1q_f32(in + i + 4); + + const int32x4_t vraw0123 = vaddq_s32( + voffset, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale)))); + const int32x4_t vraw4567 = vaddq_s32( + voffset, + vreinterpretq_s32_f32( + vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale)))); + + const int16x8_t vraw01234567 = + vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567)); + + // Convert to T (int8/uint8) with saturation using type-specific operation + const auto vout01234567 = Traits::narrow_and_saturate(vraw01234567); + Traits::store(out + i, vout01234567); + } + + // Handle remaining elements with proper quant_min/quant_max clamping + for (; i < N; ++i) { + float val = in[i] * inv_scale; + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::max(quant_min, std::min(quant_max, qval)); + out[i] = static_cast(qval); + } +#endif +} + +#endif // defined(__aarch64__) || defined(__ARM_NEON__) + Tensor& quantize_per_tensor_out( const Tensor& input, double scale, @@ -120,19 +262,44 @@ Tensor& quantize_per_tensor_out( check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - // calculate the quantized input -#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. */ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - IN_CTYPE value = input_data_ptr[i]; \ - out_data_ptr[i] = quantize_val( \ - scale, zero_point, value, quant_min, quant_max); \ - } \ + // Try ARM NEON optimized path for float->int8/uint8 quantization +#if defined(__aarch64__) || defined(__ARM_NEON__) + if (input.scalar_type() == ScalarType::Float) { + if (dtype == ScalarType::Byte) { + quantize_arm( + input.const_data_ptr(), + out.mutable_data_ptr(), + input.numel(), + 1.0f / static_cast(scale), + static_cast(zero_point), + static_cast(quant_min), + static_cast(quant_max)); + return out; + } else if (dtype == ScalarType::Char) { + quantize_arm( + input.const_data_ptr(), + out.mutable_data_ptr(), + input.numel(), + 1.0f / static_cast(scale), + static_cast(zero_point), + static_cast(quant_min), + static_cast(quant_max)); + return out; + } + } +#endif + + // Fallback scalar implementation for all other cases +#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + scale, zero_point, value, quant_min, quant_max); \ + } \ } break; #define CALCULATE_FLOAT_TYPE(IN_CTYPE, in_dtype) \ case ScalarType::in_dtype: \ @@ -284,29 +451,138 @@ Tensor& quantize_per_channel_out( const double* scale_data = scale.const_data_ptr(); const int64_t* zero_point_data = zero_point.const_data_ptr(); - // High-performance single loop with direct channel calculation -#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: { \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const int64_t input_numel = input.numel(); \ - const int64_t axis_size = input.size(axis); \ - /* Calculate the stride pattern for efficient channel index calculation */ \ - int64_t axis_block_size = 1; \ - for (int64_t i = axis + 1; i < input.dim(); i++) { \ - axis_block_size *= input.size(i); \ - } \ - /* Single loop over all elements */ \ - for (int64_t i = 0; i < input_numel; i++) { \ - /* Calculate which channel this element belongs to */ \ - int64_t channel_idx = (i / axis_block_size) % axis_size; \ - /* Get quantization parameters for this channel */ \ - double _scale = scale_data[channel_idx]; \ - int64_t _zero_point = zero_point_data[channel_idx]; \ - /* Apply quantization */ \ - out_data_ptr[i] = quantize_val( \ - _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \ - } \ + // Calculate the block size for each channel + int64_t axis_block_size = 1; + for (int64_t i = axis + 1; i < input.dim(); i++) { + axis_block_size *= input.size(i); + } + const int64_t axis_size = input.size(axis); + + // Try ARM NEON optimized path for float->int8/uint8 quantization +#if defined(__aarch64__) || defined(__ARM_NEON__) + if (input.scalar_type() == ScalarType::Float) { + const int64_t num_blocks = input.numel() / axis_block_size; + const int64_t total_elements = input.numel(); + constexpr int64_t MIN_ELEMENTS_FOR_PARALLEL = 512; + const bool use_parallel = (total_elements >= MIN_ELEMENTS_FOR_PARALLEL); + + if (dtype == ScalarType::Byte) { + auto* out_data_ptr = out.mutable_data_ptr(); + const auto* input_data_ptr = input.const_data_ptr(); + + if (use_parallel) { + ::executorch::extension::parallel_for( + 0, num_blocks, 1, [&](const int64_t begin, const int64_t end) { + for (int64_t block = begin; block < end; ++block) { + int64_t channel_idx = block % axis_size; + float inv_scale = + 1.0f / static_cast(scale_data[channel_idx]); + int32_t zp = static_cast(zero_point_data[channel_idx]); + + const float* in_ptr = input_data_ptr + block * axis_block_size; + uint8_t* out_ptr = out_data_ptr + block * axis_block_size; + + quantize_arm( + in_ptr, + out_ptr, + axis_block_size, + inv_scale, + zp, + static_cast(quant_min), + static_cast(quant_max)); + } + }); + } else { + // Process each contiguous block (which shares the same + // scale/zero_point) + for (int64_t block = 0; block < num_blocks; ++block) { + int64_t channel_idx = block % axis_size; + float inv_scale = 1.0f / static_cast(scale_data[channel_idx]); + int32_t zp = static_cast(zero_point_data[channel_idx]); + + const float* in_ptr = input_data_ptr + block * axis_block_size; + uint8_t* out_ptr = out_data_ptr + block * axis_block_size; + + quantize_arm( + in_ptr, + out_ptr, + axis_block_size, + inv_scale, + zp, + static_cast(quant_min), + static_cast(quant_max)); + } + } + return out; + } else if (dtype == ScalarType::Char) { + auto* out_data_ptr = out.mutable_data_ptr(); + const auto* input_data_ptr = input.const_data_ptr(); + + if (use_parallel) { + ::executorch::extension::parallel_for( + 0, num_blocks, 1, [&](const int64_t begin, const int64_t end) { + for (int64_t block = begin; block < end; ++block) { + int64_t channel_idx = block % axis_size; + float inv_scale = + 1.0f / static_cast(scale_data[channel_idx]); + int32_t zp = static_cast(zero_point_data[channel_idx]); + + const float* in_ptr = input_data_ptr + block * axis_block_size; + int8_t* out_ptr = out_data_ptr + block * axis_block_size; + + quantize_arm( + in_ptr, + out_ptr, + axis_block_size, + inv_scale, + zp, + static_cast(quant_min), + static_cast(quant_max)); + } + }); + } else { + // Process each contiguous block (which shares the same + // scale/zero_point) + for (int64_t block = 0; block < num_blocks; ++block) { + int64_t channel_idx = block % axis_size; + float inv_scale = 1.0f / static_cast(scale_data[channel_idx]); + int32_t zp = static_cast(zero_point_data[channel_idx]); + + const float* in_ptr = input_data_ptr + block * axis_block_size; + int8_t* out_ptr = out_data_ptr + block * axis_block_size; + + quantize_arm( + in_ptr, + out_ptr, + axis_block_size, + inv_scale, + zp, + static_cast(quant_min), + static_cast(quant_max)); + } + } + return out; + } + } +#endif + + // Fallback scalar implementation +#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const int64_t input_numel = input.numel(); \ + /* Single loop over all elements */ \ + for (int64_t i = 0; i < input_numel; i++) { \ + /* Calculate which channel this element belongs to */ \ + int64_t channel_idx = (i / axis_block_size) % axis_size; \ + /* Get quantization parameters for this channel */ \ + double _scale = scale_data[channel_idx]; \ + int64_t _zero_point = zero_point_data[channel_idx]; \ + /* Apply quantization */ \ + out_data_ptr[i] = quantize_val( \ + _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \ + } \ } break; #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl index f29f1f013b7..88a3823c5f3 100644 --- a/kernels/quantized/cpu/targets.bzl +++ b/kernels/quantized/cpu/targets.bzl @@ -9,6 +9,7 @@ _QUANT_OPS = ( name = "op_choose_qparams", deps = [ "//executorch/kernels/portable/cpu:vec_ops", + "//executorch/extension/threadpool:threadpool", ], ), op_target( @@ -51,6 +52,9 @@ _QUANT_OPS = ( ), op_target( name = "op_quantize", + deps = [ + "//executorch/extension/threadpool:threadpool", + ], ), ) diff --git a/kernels/quantized/test/op_choose_qparams_test.cpp b/kernels/quantized/test/op_choose_qparams_test.cpp index 13426bfdd86..dc92df80488 100644 --- a/kernels/quantized/test/op_choose_qparams_test.cpp +++ b/kernels/quantized/test/op_choose_qparams_test.cpp @@ -15,6 +15,7 @@ #include #include +#include #include using namespace ::testing; @@ -163,3 +164,97 @@ TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, DynamicShapeFloat) { EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, new_expected_scale, 1e-4, 1e-4); EXPECT_TENSOR_EQ(zero_point_out, new_expected_zero_point); } + +TEST( + OpChooseQparamsPerTokenAsymmetricTensorOutTest, + LargeInputParallelization) { + et_pal_init(); + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Create input with 8 tokens x 128 elements per token = 1024 total elements + // This exceeds the MIN_ELEMENTS_FOR_PARALLEL threshold of 512 + const int num_tokens = 8; + const int token_size = 128; + std::vector input_data(num_tokens * token_size); + + // Generate test data with known min/max per token for easier verification + std::vector expected_min(num_tokens); + std::vector expected_max(num_tokens); + + for (int i = 0; i < num_tokens; i++) { + float token_min = -1.0f * (i + 1); + float token_max = 2.0f * (i + 1); + expected_min[i] = token_min; + expected_max[i] = token_max; + + for (int j = 0; j < token_size; j++) { + // Linearly interpolate between min and max + float t = j / static_cast(token_size - 1); + input_data[i * token_size + j] = token_min + t * (token_max - token_min); + } + } + + Tensor input = tf_float.make({num_tokens, token_size}, input_data); + Tensor scale_out = tf_double.zeros({num_tokens, 1}); + Tensor zero_point_out = tf_long.zeros({num_tokens, 1}); + + choose_qparams_per_token_asymmetric_out( + input, ScalarType::Float, scale_out, zero_point_out); + + // Manually calculate expected scale and zero_point using the same algorithm + // as calculate_scale_and_zero_point function + const int32_t qmin = -128; + const int32_t qmax = 127; + const float SMALL_SCALE_THRESHOLD = 6.1e-5f; + + for (int i = 0; i < num_tokens; i++) { + float min = std::min(expected_min[i], 0.0f); + float max = std::max(expected_max[i], 0.0f); + + // Calculate scale + double scale = (static_cast(max) - min) / (qmax - qmin); + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + // Cut off small scale + if (scale < SMALL_SCALE_THRESHOLD) { + scale = SMALL_SCALE_THRESHOLD; + if (min == 0.0f) { + max = SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else if (max == 0.0f) { + min = -SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / scale; + min *= amplifier; + max *= amplifier; + } + } + + // Calculate zero_point + double zero_point_from_min = qmin - min / scale; + double zero_point_from_max = qmax - max / scale; + double zero_point_from_min_error = std::abs(qmin) - std::abs(min / scale); + double zero_point_from_max_error = std::abs(qmax) - std::abs(max / scale); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + int32_t nudged_zero_point = 0; + if (initial_zero_point < qmin) { + nudged_zero_point = qmin; + } else if (initial_zero_point > qmax) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = + std::nearbyint(static_cast(initial_zero_point)); + } + + // Verify computed values match expected + EXPECT_NEAR(scale_out.const_data_ptr()[i], scale, 1e-6); + EXPECT_EQ(zero_point_out.const_data_ptr()[i], nudged_zero_point); + } +} diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 4ac835c24ce..b450ec0ee33 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -14,7 +14,6 @@ #include #include -#include using namespace ::testing; using executorch::aten::ArrayRef; @@ -446,3 +445,539 @@ TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) { EXPECT_TENSOR_EQ(out, expected); } + +TEST(OpQuantizeOutTest, LargePerChannelClampingSIMDPath) { + // Test quant_min/quant_max clamping with large tensor to exercise SIMD path + // Shape: [3, 80] with axis=0 (3 channels, 80 elements each) + // 80 elements = 10 SIMD iterations (8 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int num_channels = 3; + const int block_size = 80; + std::vector input_data(num_channels * block_size); + + // Create input data with values that exceed quant_min/quant_max + for (int ch = 0; ch < num_channels; ch++) { + for (int i = 0; i < block_size; i++) { + // Generate values from -150 to 150 to test clamping + input_data[ch * block_size + i] = + static_cast((i % 40) - 20) * 5.0f * (ch + 1); + } + } + Tensor input = tf_float.make({num_channels, block_size}, input_data); + + // Use uniform scale and zero_point for all channels + Tensor scale = tf_double.make({num_channels}, {1.0, 1.0, 1.0}); + Tensor zero_point = tf_long.make({num_channels}, {0, 0, 0}); + + // Set narrow quant_min/quant_max to force clamping + int64_t quant_min = -20; + int64_t quant_max = 20; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values with clamping + std::vector expected_data(num_channels * block_size); + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + // Use double precision to avoid overflow + double val = static_cast(input_data[idx]) / ch_scale; + // Clamp before converting to int to avoid overflow + val = std::max(-1000.0, std::min(1000.0, val)); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + // Apply quant_min/quant_max clamping + qval = std::max( + static_cast(quant_min), + std::min(static_cast(quant_max), qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +// Large tensor tests to ensure ARM NEON SIMD path is exercised + +TEST(OpQuantizeOutTest, LargeTensorUInt8SIMDPath) { + // Test with 64 elements to fully exercise SIMD path (8 elements per + // iteration) + TensorFactory tf_float; + + // Create input with known values for verification + std::vector input_data(64); + for (size_t i = 0; i < 64; i++) { + input_data[i] = static_cast(i) * 0.5f; // 0.0, 0.5, 1.0, 1.5, ... + } + Tensor input = tf_float.make({64}, input_data); + + double scale = 0.1; + int64_t zero_point = 10; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({64}); + + // Compute expected values: round(value / scale) + zero_point + std::vector expected_data(64); + for (size_t i = 0; i < 64; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({64}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargeTensorInt8SIMDPath) { + // Test with 72 elements (9 SIMD iterations of 8) to test both vectorized and + // scalar paths + TensorFactory tf_float; + + std::vector input_data(72); + for (size_t i = 0; i < 72; i++) { + // Mix of positive and negative values + input_data[i] = static_cast(static_cast(i) - 36) * 0.25f; + } + Tensor input = tf_float.make({72}, input_data); + + double scale = 0.2; + int64_t zero_point = 0; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({72}); + + // Compute expected values + std::vector expected_data(72); + for (size_t i = 0; i < 72; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(127, std::max(-128, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({72}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargeTensorWithRemainderUInt8) { + // Test with 100 elements (12 SIMD iterations + 4 remainder) to test remainder + // handling + TensorFactory tf_float; + + std::vector input_data(100); + for (size_t i = 0; i < 100; i++) { + input_data[i] = static_cast(i % 50) * 0.3f; + } + Tensor input = tf_float.make({100}, input_data); + + double scale = 0.15; + int64_t zero_point = 128; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({100}); + + std::vector expected_data(100); + for (size_t i = 0; i < 100; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({100}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargeTensorWithRemainderInt8) { + // Test with 99 elements (12 SIMD iterations + 3 remainder) + TensorFactory tf_float; + + std::vector input_data(99); + for (size_t i = 0; i < 99; i++) { + input_data[i] = std::sin(static_cast(i) * 0.1f) * 10.0f; + } + Tensor input = tf_float.make({99}, input_data); + + double scale = 0.1; + int64_t zero_point = 5; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({99}); + + std::vector expected_data(99); + for (size_t i = 0; i < 99; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(127, std::max(-128, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({99}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, VeryLargeTensor2DUInt8) { + // Test with realistic 2D tensor size that would be used in neural networks + // 256x256 = 65536 elements (8192 SIMD iterations) + TensorFactory tf_float; + + std::vector input_data(256 * 256); + for (size_t i = 0; i < 256 * 256; i++) { + // Generate diverse values in a safe range + input_data[i] = + static_cast((static_cast(i % 256) - 128)) * 0.05f; + } + Tensor input = tf_float.make({256, 256}, input_data); + + double scale = 0.05; + int64_t zero_point = 128; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({256, 256}); + + // Compute expected values with proper overflow handling + std::vector expected_data(256 * 256); + for (size_t i = 0; i < 256 * 256; i++) { + // Use double precision to avoid overflow + double val = static_cast(input_data[i]) / scale; + // Clamp before converting to int to avoid overflow + val = std::max(-1000.0, std::min(1000.0, val)); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(zero_point); + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({256, 256}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, VeryLargeTensor3DInt8) { + // Test with 3D tensor (batch_size=2, height=64, width=128) = 16384 elements + TensorFactory tf_float; + + const size_t total_elements = 2 * 64 * 128; + std::vector input_data(total_elements); + for (size_t i = 0; i < total_elements; i++) { + input_data[i] = std::cos(static_cast(i) * 0.01f) * 8.0f; + } + Tensor input = tf_float.make({2, 64, 128}, input_data); + + double scale = 0.0625; // 1/16 + int64_t zero_point = -10; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 64, 128}); + + std::vector expected_data(total_elements); + for (size_t i = 0; i < total_elements; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(127, std::max(-128, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({2, 64, 128}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, EdgeCaseSizesSIMD) { + // Test specific sizes around SIMD boundaries + TensorFactory tf_float; + TensorFactory tfo; + + double scale = 0.1; + int64_t zero_point = 100; + int64_t quant_min = 0; + int64_t quant_max = 255; + + // Test sizes: 7 (just before SIMD), 8 (exactly 1 SIMD), 9 (1 SIMD + 1), 15, + // 16, 17 + std::vector test_sizes = { + 7, 8, 9, 15, 16, 17, 23, 24, 25, 31, 32, 33}; + + for (size_t size : test_sizes) { + std::vector input_data(size); + std::vector expected_data(size); + + for (size_t i = 0; i < size; i++) { + input_data[i] = static_cast(i) * 0.3f; + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + + Tensor input = tf_float.make({static_cast(size)}, input_data); + Tensor out = tfo.zeros({static_cast(size)}); + Tensor expected = tfo.make({static_cast(size)}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); + } +} + +// Large tensor tests for per-channel quantization to ensure SIMD path is +// exercised + +TEST(OpQuantizeOutTest, LargePerChannelUInt8SIMDPath) { + // Test per-channel quantization with large blocks (64 elements per channel) + // Shape: [4, 64] with axis=1 (4 channels, 64 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int num_channels = 4; + const int block_size = 64; + std::vector input_data(num_channels * block_size); + + // Create varying input data for each channel + for (int ch = 0; ch < num_channels; ch++) { + for (int i = 0; i < block_size; i++) { + input_data[ch * block_size + i] = static_cast((ch + 1) * i) * 0.1f; + } + } + Tensor input = tf_float.make({num_channels, block_size}, input_data); + + // Different scale and zero_point for each channel + Tensor scale = tf_double.make({num_channels}, {0.1, 0.2, 0.15, 0.25}); + Tensor zero_point = tf_long.make({num_channels}, {10, 20, 15, 25}); + + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values + std::vector expected_data(num_channels * block_size); + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + float val = input_data[idx] / static_cast(ch_scale); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + qval = std::min(255, std::max(0, qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargePerChannelInt8SIMDPath) { + // Test per-channel quantization with int8 and large blocks + // Shape: [3, 100] with axis=1 (3 channels, 100 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int num_channels = 3; + const int block_size = 100; // 12 SIMD iterations + 4 remainder + std::vector input_data(num_channels * block_size); + + // Create varying input data with negative values + for (int ch = 0; ch < num_channels; ch++) { + for (int i = 0; i < block_size; i++) { + input_data[ch * block_size + i] = + static_cast(i - 50) * 0.2f * (ch + 1); + } + } + Tensor input = tf_float.make({num_channels, block_size}, input_data); + + Tensor scale = tf_double.make({num_channels}, {0.1, 0.15, 0.2}); + Tensor zero_point = tf_long.make({num_channels}, {0, -5, 5}); + + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values + std::vector expected_data(num_channels * block_size); + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + float val = input_data[idx] / static_cast(ch_scale); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + qval = std::min(127, std::max(-128, qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, VeryLargePerChannel2DUInt8) { + // Test realistic neural network weight tensor + // Shape: [128, 256] with axis=0 (128 channels, 256 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int num_channels = 128; + const int block_size = 256; + const int total_elements = num_channels * block_size; + + std::vector input_data(total_elements); + for (int i = 0; i < total_elements; i++) { + input_data[i] = std::sin(static_cast(i) * 0.01f) * 5.0f; + } + Tensor input = tf_float.make({num_channels, block_size}, input_data); + + // Create varying scales and zero_points for each channel + std::vector scales(num_channels); + std::vector zero_points(num_channels); + for (int ch = 0; ch < num_channels; ch++) { + scales[ch] = 0.02 + (ch % 10) * 0.001; // Varying scales + zero_points[ch] = 128 + (ch % 5); // Varying zero_points + } + Tensor scale = tf_double.make({num_channels}, scales); + Tensor zero_point = tf_long.make({num_channels}, zero_points); + + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values + std::vector expected_data(total_elements); + for (int ch = 0; ch < num_channels; ch++) { + float inv_scale = 1.0f / static_cast(scales[ch]); + int64_t ch_zero_point = zero_points[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + float val = input_data[idx] * inv_scale; + // Clamp before converting to avoid overflow + val = std::max(-1000.0f, std::min(1000.0f, val)); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + + qval = std::min(255, std::max(0, qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, PerChannelAxis1LargeBlocks) { + // Test per-channel quantization with axis=1 and large contiguous blocks + // Shape: [2, 3, 64] with axis=1 (2 batches, 3 channels, 64 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int batch_size = 2; + const int num_channels = 3; + const int block_size = 64; + const int total_elements = batch_size * num_channels * block_size; + + std::vector input_data(total_elements); + for (int i = 0; i < total_elements; i++) { + input_data[i] = static_cast(i % 100) * 0.1f; + } + Tensor input = + tf_float.make({batch_size, num_channels, block_size}, input_data); + + Tensor scale = tf_double.make({num_channels}, {0.05, 0.1, 0.15}); + Tensor zero_point = tf_long.make({num_channels}, {100, 110, 120}); + + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({batch_size, num_channels, block_size}); + + // Compute expected values + std::vector expected_data(total_elements); + for (int b = 0; b < batch_size; b++) { + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = (b * num_channels + ch) * block_size + i; + float val = input_data[idx] / static_cast(ch_scale); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + qval = std::min(255, std::max(0, qval)); + expected_data[idx] = static_cast(qval); + } + } + } + Tensor expected = + tfo.make({batch_size, num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} diff --git a/kernels/test/op_index_test.cpp b/kernels/test/op_index_test.cpp index 787eb4612d8..8816d0a8d3f 100644 --- a/kernels/test/op_index_test.cpp +++ b/kernels/test/op_index_test.cpp @@ -947,3 +947,56 @@ TEST_F(OpIndexTensorOutTest, FastPathEmptyInput) { EXPECT_TENSOR_EQ(out, expected); } + +TEST_F(OpIndexTensorOutTest, FastPathNegativeIndex) { + TensorFactory tf; + TensorFactory tfl; + + // clang-format off + Tensor x = tf.make( + {2, 3, 4}, + { + // [0, :, :] + 1., 2., 3., 4., // [0, 0, :] + 5., 6., 7., 8., // [0, 1, :] + 9., 10., 11., 12., // [0, 2, :] + + // [1, :, :] + -1., -2., -3., -4., // [1, 0, :] + -5., -6., -7., -8., // [1, 1, :] + -9., -10., -11., -12., // [1, 2, :] + }); + // clang-format on + + // Use negative indices in the first dimension: -1, 0, -2 + std::array, 3> indices = { + optional(tfl.make({3}, {-1, 0, -2})), + optional(), + optional()}; + + Tensor out = tf.zeros({3, 3, 4}); + // clang-format off + Tensor expected = tf.make( + {3, 3, 4}, + { + // [1, :, :] + -1., -2., -3., -4., // [1, 0, :] + -5., -6., -7., -8., // [1, 1, :] + -9., -10., -11., -12., // [1, 2, :] + + // [0, :, :] + 1., 2., 3., 4., // [0, 0, :] + 5., 6., 7., 8., // [0, 1, :] + 9., 10., 11., 12., // [0, 2, :] + + // [0, :, :] again (since -2 wraps to 0) + 1., 2., 3., 4., // [0, 0, :] + 5., 6., 7., 8., // [0, 1, :] + 9., 10., 11., 12., // [0, 2, :] + }); + // clang-format on + + op_index_tensor_out(x, indices, out); + + EXPECT_TENSOR_EQ(out, expected); +} diff --git a/pyproject.toml b/pyproject.toml index cf42f3a1ea4..79b442aa37b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ readme = "README-wheel.md" authors = [ {name="PyTorch Team", email="packages@pytorch.org"}, ] -license = "BSD-3-Clause" -license-files = ["LICENSE"] +license = {text = "BSD-3-Clause"} + keywords = ["pytorch", "machine learning"] # PyPI package information. classifiers = [ @@ -62,7 +62,7 @@ dependencies=[ "packaging", "pandas>=2.2.2; python_version >= '3.10'", "parameterized", - "pytest", + "pytest<9.0", "pytest-xdist", "pytest-rerunfailures==15.1", "pytest-json-report", @@ -74,7 +74,7 @@ dependencies=[ # See also third-party/TARGETS for buck's typing-extensions version. "typing-extensions>=4.10.0", # Keep this version in sync with: ./backends/apple/coreml/scripts/install_requirements.sh - "coremltools==9.0b1; platform_system == 'Darwin' or platform_system == 'Linux'", + "coremltools==9.0; platform_system == 'Darwin' or platform_system == 'Linux'", # scikit-learn is used to support palettization in the coreml backend "scikit-learn==1.7.1", "hydra-core>=1.3.0", @@ -97,6 +97,9 @@ flatc = "executorch.data.bin:flatc" # TODO(dbort): Could use py_modules to restrict the set of modules we # package, and package_data to restrict the set up non-python files we # include. See also setuptools/discovery.py for custom finders. +[tool.setuptools] +license-files = ["LICENSE"] + [tool.setuptools.package-dir] # Tell setuptools to follow the symlink: src/executorch/* -> * for all first level # modules such as src/executorch/exir -> exir. This helps us to semi-compliant with diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index 9b490da244c..26b97e5a7a2 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -1212,6 +1212,12 @@ ET_NODISCARD inline Error resize_tensor( std::array new_sizes_casted{}; size_t new_sizes_ndim = new_sizes.size(); + ET_CHECK_OR_RETURN_ERROR( + new_sizes_ndim <= kTensorDimensionLimit, + InvalidArgument, + "new_sizes_ndim %zu is greater than kTensorDimensionLimit %zu", + new_sizes_ndim, + kTensorDimensionLimit); for (const auto i : c10::irange(new_sizes_ndim)) { new_sizes_casted[i] = static_cast(new_sizes[i]); diff --git a/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h b/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h index e340e7626a0..7c46eda0912 100644 --- a/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +++ b/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h @@ -611,4 +611,60 @@ __host__ __device__ #define C10_RETURN_MOVE_IF_OLD_COMPILER 0 #endif +// The HIDDEN_NAMESPACE_BEGIN and HIDDEN_NAMESPACE_END below +// are needed for maintaining robustness in our header APIs in +// torch/headeronly and torch/csrc/stable under the namespaces +// torch::headeronly and torch::stable respectively. We enforce +// hidden visibility for these APIs because we want to enable +// loading custom extensions compiled against different libtorch +// versions where these APIs may have changed. + +// Helper macros for nested namespace expansion +#define _EXPAND(...) __VA_ARGS__ + +// Macros to handle 1-3 hidden namespace levels when not windows +#define _HIDDEN_NS_GET_MACRO(_1, _2, _3, NAME, ...) NAME +#define _HIDDEN_NS_1(n1) namespace n1 __attribute__((visibility("hidden"))) { +#define _HIDDEN_NS_2(n1, n2) \ + namespace n1 { \ + namespace n2 __attribute__((visibility("hidden"))) { +#define _HIDDEN_NS_3(n1, n2, n3) \ + namespace n1::n2 { \ + namespace n3 __attribute__((visibility("hidden"))) { + +// Macros to close namespaces when not windows +#define _HIDDEN_NS_END_1(n1) } +#define _HIDDEN_NS_END_N(n1, ...) \ + } \ + } + +// Macros to join strs with :: (for win, where symbols are hidden by default) +#define _JOIN_GET_MACRO(_1, _2, _3, NAME, ...) NAME +#define _JOIN_NS1(a) a +#define _JOIN_NS2(a, b) a::b +#define _JOIN_NS3(a, b, c) a::b::c + +#if !defined(HIDDEN_NAMESPACE_BEGIN) +#if defined(__GNUG__) && !defined(_WIN32) +#define HIDDEN_NAMESPACE_BEGIN(...) \ + _EXPAND(_HIDDEN_NS_GET_MACRO( \ + __VA_ARGS__, _HIDDEN_NS_3, _HIDDEN_NS_2, _HIDDEN_NS_1)(__VA_ARGS__)) +#else +#define HIDDEN_NAMESPACE_BEGIN(...) \ + namespace _EXPAND(_JOIN_GET_MACRO( \ + __VA_ARGS__, _JOIN_NS3, _JOIN_NS2, _JOIN_NS1)(__VA_ARGS__)) { +#endif +#endif + +#if !defined(HIDDEN_NAMESPACE_END) +#if defined(__GNUG__) && !defined(_WIN32) +#define HIDDEN_NAMESPACE_END(...) \ + _EXPAND(_HIDDEN_NS_GET_MACRO( \ + __VA_ARGS__, _HIDDEN_NS_END_N, _HIDDEN_NS_END_N, _HIDDEN_NS_END_1)( \ + __VA_ARGS__)) +#else +#define HIDDEN_NAMESPACE_END(...) } +#endif +#endif + #endif // C10_MACROS_MACROS_H_ diff --git a/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h b/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h index ac47e3f844a..64479ba36f1 100644 --- a/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +++ b/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h @@ -395,7 +395,7 @@ inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { C10_CLANG_DIAGNOSTIC_POP() } // namespace c10 -namespace torch::headeronly { +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) namespace detail { using c10::detail::bits_from_f32; @@ -415,7 +415,7 @@ using c10::operator/=; using c10::operator<; using c10::operator>; using c10::operator<<; -} // namespace torch::headeronly +HIDDEN_NAMESPACE_END(torch, headeronly) namespace std { diff --git a/runtime/core/portable_type/c10/torch/headeronly/util/Half.h b/runtime/core/portable_type/c10/torch/headeronly/util/Half.h index 9673301e2de..a9c0b166ba2 100644 --- a/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +++ b/runtime/core/portable_type/c10/torch/headeronly/util/Half.h @@ -698,7 +698,7 @@ C10_CLANG_DIAGNOSTIC_POP() } // namespace c10 -namespace torch::headeronly { +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) using c10::Half; using c10::operator+; @@ -724,7 +724,7 @@ using c10::detail::fp16_ieee_to_fp32_bits; using c10::detail::fp16_ieee_to_fp32_value; } // namespace detail -} // namespace torch::headeronly +HIDDEN_NAMESPACE_END(torch, headeronly) namespace std { diff --git a/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h b/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h index 561ea0467a0..f41269082d9 100644 --- a/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +++ b/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h @@ -139,10 +139,10 @@ inline constexpr bool less_than_lowest(const T& x) { C10_CLANG_DIAGNOSTIC_POP() -namespace torch::headeronly { +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) using c10::greater_than_max; using c10::is_negative; using c10::less_than_lowest; using c10::signs_differ; using c10::signum; -} // namespace torch::headeronly +HIDDEN_NAMESPACE_END(torch, headeronly) diff --git a/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h b/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h index 334ba5b8e5b..3f357f8a06a 100644 --- a/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +++ b/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h @@ -13,7 +13,7 @@ #endif // __has_include() && (__cplusplus >= 202002L || // (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)) -namespace torch::headeronly { +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) #if C10_HAVE_STD_BIT_CAST using std::bit_cast; @@ -43,7 +43,7 @@ bit_cast(const From& src) noexcept { #endif // C10_HAVE_STD_BIT_CAST #undef C10_HAVE_STD_BIT_CAST -} // namespace torch::headeronly +HIDDEN_NAMESPACE_END(torch, headeronly) namespace c10 { using torch::headeronly::bit_cast; diff --git a/runtime/core/portable_type/c10/torch/headeronly/util/complex.h b/runtime/core/portable_type/c10/torch/headeronly/util/complex.h index e0a356436ac..733a22d5dbb 100644 --- a/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +++ b/runtime/core/portable_type/c10/torch/headeronly/util/complex.h @@ -590,7 +590,7 @@ struct alignas(4) complex { } // namespace c10 -namespace torch::headeronly { +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) using c10::complex; using c10::operator+; using c10::operator-; @@ -611,6 +611,6 @@ using c10::complex_literals::operator""_if; using c10::complex_literals::operator""_id; } // namespace complex_literals -} // namespace torch::headeronly +HIDDEN_NAMESPACE_END(torch, headeronly) C10_CLANG_DIAGNOSTIC_POP() diff --git a/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h b/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h index c469cc6a4f6..1e60bd85c10 100644 --- a/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +++ b/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h @@ -4,7 +4,7 @@ #include #include -namespace torch::headeronly::detail { +HIDDEN_NAMESPACE_BEGIN(torch, headeronly, detail) C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) { #if defined(__OPENCL_VERSION__) @@ -30,7 +30,7 @@ C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) { #endif } -} // namespace torch::headeronly::detail +HIDDEN_NAMESPACE_END(torch, headeronly, detail) namespace c10::detail { using torch::headeronly::detail::fp32_from_bits; diff --git a/runtime/platform/compiler.h b/runtime/platform/compiler.h index 62324699923..edd340d1fb0 100644 --- a/runtime/platform/compiler.h +++ b/runtime/platform/compiler.h @@ -161,8 +161,8 @@ // As of G3 RJ-2024.3 toolchain, zu format specifier is not supported for Xtensa #if defined(__XTENSA__) -#define ET_PRIsize_t "lu" -#define ET_PRIssize_t "ld" +#define ET_PRIsize_t "u" +#define ET_PRIssize_t "d" #else #define ET_PRIsize_t "zu" #define ET_PRIssize_t "zd" diff --git a/scripts/check_model_export_times.py b/scripts/check_model_export_times.py new file mode 100644 index 00000000000..f85a7c5a793 --- /dev/null +++ b/scripts/check_model_export_times.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import argparse +import re +from collections import defaultdict +from datetime import datetime + +import requests + + +class GithubActionsClient: + + def __init__(self, token: str): + + self.base_url = "https://api.github.com/repos/pytorch/executorch" + self.__headers = { + "Authorization": f"token {token}", + "Accept": "application/vnd.github+json", + } + + def get_runs(self, params=None): + + runs_url = f"{self.base_url}/actions/runs" + response = requests.get(runs_url, headers=self.__headers, params=params) + response.raise_for_status() + + return response.json()["workflow_runs"] + + def get_jobs(self, run_id: int, jobs_per_page: int = 100): + + jobs_url = f"{self.base_url}/actions/runs/{run_id}/jobs" + all_jobs = [] + page = 1 + + while True: + response = requests.get( + jobs_url, + headers=self.__headers, + params={"per_page": jobs_per_page, "page": page}, + ) + response.raise_for_status() + + json_response = response.json() + jobs = json_response["jobs"] + + if not jobs: # No more jobs + break + + all_jobs.extend(jobs) + + # Stop if we got fewer jobs than requested (last page) + if len(jobs) < jobs_per_page: + break + + page += 1 + + return all_jobs + + def get_job_logs(self, job_id: int): + + logs_url = f"{self.base_url}/actions/jobs/{job_id}/logs" + response = requests.get(logs_url, headers=self.__headers) + response.raise_for_status() + + return response.content.decode() + + +def extract_model_export_times(log): + + duration = re.search(r"Model export completed .* Duration: (\d+)", log) + docker_image = re.search(r"DOCKER_IMAGE:\s*(.+?)(?:\s|$)", log) + dtype = re.search(r"DTYPE=(\w+)", log) + mode = re.search(r"MODE=(\S+)", log) + runner = re.search(r"runner:\s*(\S+)", log) + + log_extract = { + "duration": duration.group(1) if duration else None, + "docker_image": docker_image.group(1) if docker_image else None, + "dtype": dtype.group(1) if dtype else None, + "mode": mode.group(1) if mode else None, + "runner": runner.group(1) if runner else None, + } + + return log_extract + + +def extract_full_model_export_times(gha_client, filters=None, run_id=None): + + if run_id: + # run_id will be a list when using nargs='+' + if isinstance(run_id, list): + all_runs = [{"id": rid} for rid in run_id] + else: + # Fallback for single string + all_runs = [{"id": run_id}] + else: + # No run_id provided, fetch runs using filters + all_runs = gha_client.get_runs(params=filters) + + model_tracker = defaultdict(list) + + for idx, run in enumerate(all_runs, 1): + + run_id_val = run["id"] + print(f"Processing run {idx}/{len(all_runs)}: ID {run_id_val}") + + try: + jobs = gha_client.get_jobs(run_id_val) + + for job in jobs: + + if job["conclusion"] == "skipped": + continue + + if not ("test-llama" in job["name"]): + continue + + try: + log = gha_client.get_job_logs(job_id=job["id"]) + + extracted_config = extract_model_export_times(log) + extracted_config["job_name"] = job["name"] + + if extracted_config["duration"]: + model_tracker[run_id_val].append(extracted_config) + + except Exception as e: + print(f" Warning: Failed to get logs for job {job['id']}: {e}") + continue + + except Exception as e: + print(f" Error: Failed to get jobs for run {run_id_val}: {e}") + continue + + return model_tracker + + +def print_results_as_table(results_dict): + """Print results as a formatted markdown table.""" + + # Extract all jobs from the defaultdict + all_jobs = [] + for run_id, jobs in results_dict.items(): + for job in jobs: + job["run_id"] = run_id # Add run_id to each job + all_jobs.append(job) + + if not all_jobs: + print("No jobs found.") + return + + # Print header + print("\n## Model Export Times\n") + print("| Run ID | Job Name | DType | Mode | Runner | Docker Image | Duration (s) |") + print("|--------|----------|-------|------|--------|--------------|--------------|") + + # Print each job + for job in all_jobs: + run_id = job.get("run_id", "N/A") + job_name = job.get("job_name", "N/A")[:60] # Truncate long names + dtype = job.get("dtype", "N/A") + mode = job.get("mode", "N/A") + runner = job.get("runner", "N/A") + docker_image = job.get("docker_image", "None") + duration = job.get("duration", "N/A") + + # Truncate docker image if too long + if docker_image and len(docker_image) > 40: + docker_image = docker_image[:37] + "..." + + print( + f"| {run_id} | {job_name} | {dtype} | {mode} | {runner} | {docker_image} | {duration} |" + ) + + # Print summary statistics + print(f"\n**Total Jobs:** {len(all_jobs)}") + + # Calculate average duration + durations = [ + int(job["duration"]) for job in all_jobs if job.get("duration", "").isdigit() + ] + if durations: + avg_duration = sum(durations) / len(durations) + print(f"**Average Duration:** {avg_duration:.1f} seconds") + print(f"**Min Duration:** {min(durations)} seconds") + print(f"**Max Duration:** {max(durations)} seconds") + + +def main(): + + parser = argparse.ArgumentParser( + description="A tool to get all model export times for the different configurations based on the githug actions runs" + ) + + parser.add_argument( + "--github_token", + metavar="executable", + type=str, + help="Your github access token", + default="", + ) + + parser.add_argument( + "--created_time", + metavar="executable", + type=str, + help="The date of the earliest github runs to include of the format YYYY-MM-DD", + default=datetime.today().strftime("%Y-%m-%d"), + ) + + parser.add_argument( + "--run_id", + metavar="RUN_ID", + type=str, + nargs="+", # Accept one or more arguments + help="One or more run IDs to extract model export times from", + default=None, + ) + + args = parser.parse_args() + + gha_client = GithubActionsClient(token=args.github_token) + + filters = {"created": f">={args.created_time}"} + + model_tracker_output = extract_full_model_export_times( + gha_client, filters=filters, run_id=args.run_id + ) + + print_results_as_table(model_tracker_output) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 45b69b8b828..19f2dd0131e 100644 --- a/setup.py +++ b/setup.py @@ -57,9 +57,6 @@ import site import subprocess import sys -import sysconfig -import tempfile - from distutils import log # type: ignore[import-not-found] from distutils.sysconfig import get_python_lib # type: ignore[import-not-found] from pathlib import Path @@ -463,84 +460,6 @@ def run(self): if self._ran_build: return - try: - # Following code is for building the Qualcomm backend. - from backends.qualcomm.scripts.download_qnn_sdk import ( - _download_qnn_sdk, - is_linux_x86, - ) - - if is_linux_x86(): - os.environ["EXECUTORCH_BUILDING_WHEEL"] = "1" - - with tempfile.TemporaryDirectory() as tmpdir: - tmp_path = Path(tmpdir) - sdk_path = _download_qnn_sdk(dst_folder=tmp_path) - - if not sdk_path: - raise RuntimeError( - "Qualcomm SDK not found, cannot build backend" - ) - - # Determine paths - prj_root = Path(__file__).parent.resolve() - build_sh = prj_root / "backends/qualcomm/scripts/build.sh" - build_root = prj_root / "build-x86" - - if not build_sh.exists(): - raise FileNotFoundError(f"{build_sh} not found") - - # Run build.sh with SDK path exported - env = dict(**os.environ) - env["QNN_SDK_ROOT"] = str(sdk_path) - subprocess.check_call( - [ - str(build_sh), - "--skip_linux_android", - "--skip_linux_embedded", - ], - env=env, - ) - - # Copy the main .so into the wheel package - so_src = ( - build_root / "backends/qualcomm/libqnn_executorch_backend.so" - ) - so_dst = Path( - self.get_ext_fullpath( - "executorch.backends.qualcomm.qnn_backend" - ) - ) - self.mkpath(str(so_dst.parent)) # ensure destination exists - self.copy_file(str(so_src), str(so_dst)) - logging.info(f"Copied Qualcomm backend: {so_src} -> {so_dst}") - - # Copy Python adaptor .so files - ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") - - so_files = [ - ( - "executorch.backends.qualcomm.python.PyQnnManagerAdaptor", - prj_root - / f"backends/qualcomm/python/PyQnnManagerAdaptor{ext_suffix}", - ), - ( - "executorch.backends.qualcomm.python.PyQnnWrapperAdaptor", - prj_root - / f"backends/qualcomm/python/PyQnnWrapperAdaptor{ext_suffix}", - ), - ] - - for module_name, so_src in so_files: - so_dst = Path(self.get_ext_fullpath(module_name)) - self.mkpath(str(so_dst.parent)) - self.copy_file(str(so_src), str(so_dst)) - logging.info(f"Copied Qualcomm backend: {so_src} -> {so_dst}") - - except ImportError: - logging.error("Fail to build Qualcomm backend") - logging.exception("Import error") - if self.editable_mode: self._ran_build = True self.run_command("build") @@ -632,7 +551,7 @@ def run(self): # package subdirectory. if self.editable_mode: # In editable mode, the package directory is the original source directory - dst_root = self.get_package_dir(".") + dst_root = self.get_package_dir("executorch") else: dst_root = os.path.join(self.build_lib, "executorch") # Create the version file. @@ -837,6 +756,11 @@ def run(self): # noqa C901 cmake_build_args += ["--target", "custom_ops_aot_lib"] cmake_build_args += ["--target", "quantized_ops_aot_lib"] + if cmake_cache.is_enabled("EXECUTORCH_BUILD_QNN"): + cmake_build_args += ["--target", "qnn_executorch_backend"] + cmake_build_args += ["--target", "PyQnnManagerAdaptor"] + cmake_build_args += ["--target", "PyQnnWrapperAdaptor"] + # Set PYTHONPATH to the location of the pip package. os.environ["PYTHONPATH"] = ( site.getsitepackages()[0] + ";" + os.environ.get("PYTHONPATH", "") @@ -918,5 +842,30 @@ def run(self): # noqa C901 is_dynamic_lib=True, dependent_cmake_flags=["EXECUTORCH_BUILD_KERNELS_LLM_AOT"], ), + BuiltFile( + src_dir="backends/cuda/runtime/", + src_name="aoti_cuda_shims.lib", + dst="executorch/data/lib/", + dependent_cmake_flags=[], + ), + BuiltFile( + src_dir="%CMAKE_CACHE_DIR%/backends/qualcomm/%BUILD_TYPE%/", + src_name="qnn_executorch_backend", + dst="executorch/backends/qualcomm/", + is_dynamic_lib=True, + dependent_cmake_flags=["EXECUTORCH_BUILD_QNN"], + ), + BuiltExtension( + src_dir="%CMAKE_CACHE_DIR%/backends/qualcomm/%BUILD_TYPE%/", + src="PyQnnManagerAdaptor.*", + modpath="executorch.backends.qualcomm.python.PyQnnManagerAdaptor", + dependent_cmake_flags=["EXECUTORCH_BUILD_QNN"], + ), + BuiltExtension( + src_dir="%CMAKE_CACHE_DIR%/backends/qualcomm/%BUILD_TYPE%/", + src="PyQnnWrapperAdaptor.*", + modpath="executorch.backends.qualcomm.python.PyQnnWrapperAdaptor", + dependent_cmake_flags=["EXECUTORCH_BUILD_QNN"], + ), ], ) diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index d6f8ded668b..fe4271e8dba 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -53,7 +53,7 @@ set(EXECUTORCH_FOUND ON) include("${CMAKE_CURRENT_LIST_DIR}/ExecuTorchTargets.cmake") set(optional_lib_list - aoti_cuda + aoti_cuda_backend flatccrt etdump bundled_program diff --git a/tools/cmake/preset/README.md b/tools/cmake/preset/README.md index 3e2bed9510e..7fd985fdaf4 100644 --- a/tools/cmake/preset/README.md +++ b/tools/cmake/preset/README.md @@ -37,6 +37,53 @@ $ cmake -DEXECUTORCH_BUILD_MPS=OFF --preset llm The cmake presets roughly map to the ExecuTorch presets and are explicitly listed in [CMakePresets.json](../../../CMakePresets.json). Note that you are encouraged to rely on presets when build locally and adding build/tests in CI — CI should do what a developer would do and nothing more! +### Using Workflows + +CMake workflow presets combine configure, build, and test steps into a single command. This is the recommended way to build ExecuTorch as it automates the entire build process. + +#### List available workflows + +```bash +$ cmake --workflow --list-presets +``` + +#### Run a complete workflow + +```bash +# Configure, build, and install LLM extension (CPU) +$ cmake --workflow --preset llm-release + +# Configure, build, and install LLM extension with CUDA (Linux only) +$ cmake --workflow --preset llm-release-cuda + +# Configure, build, and install LLM extension with Metal (macOS only) +$ cmake --workflow --preset llm-release-metal + +# Debug builds are also available +$ cmake --workflow --preset llm-debug +$ cmake --workflow --preset llm-debug-cuda +$ cmake --workflow --preset llm-debug-metal +``` + +#### Understanding workflow components + +A workflow preset typically consists of: +1. **Configure preset**: Defines CMake cache variables and build settings +2. **Build preset**: Specifies targets to build and parallel job count +3. **Workflow preset**: Orchestrates the configure and build steps + +For example, `llm-release` workflow: +- Uses `llm-release` configure preset (sets `CMAKE_BUILD_TYPE=Release`) +- Uses `llm-release-install` build preset (builds the `install` target with parallel jobs) +- Installs artifacts to `cmake-out/` directory + +#### Add a new workflow +To add a new workflow: +1. Add a configure preset, e.g. `new-workflow` +2. Add a build preset that depends on (1), e.g. `new-workflow-install` +3. You should be able to run `cmake --workflow new-workflow-install` + + ### Including ExecuTorch as Third-party Library #### Choose a built-in preset diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index 0dcec0df531..b4d6e7f31c3 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -47,6 +48,9 @@ define_overridable_option( define_overridable_option( EXECUTORCH_ENABLE_EVENT_TRACER "Build with ET_EVENT_TRACER_ENABLED" BOOL OFF ) +define_overridable_option( + EXECUTORCH_ENABLE_BUNDLE_IO "Build with ET_BUNDLE_IO_ENABLED" BOOL OFF +) define_overridable_option( EXECUTORCH_OPTIMIZE_SIZE "Build executorch runtime optimizing for binary size" BOOL OFF @@ -288,6 +292,12 @@ define_overridable_option( BOOL FALSE ) +define_overridable_option( + EXECUTORCH_BUILD_WHEEL_DO_NOT_USE + "On if in the wheel building process. Should only be used to guard code that is only needed for building the wheel." + BOOL + FALSE +) # ------------------------------------------------------------------------------ # Validations @@ -299,6 +309,10 @@ check_required_options_on( IF_ON EXECUTORCH_ENABLE_EVENT_TRACER REQUIRES EXECUTORCH_BUILD_DEVTOOLS ) +check_required_options_on( + IF_ON EXECUTORCH_ENABLE_BUNDLE_IO REQUIRES EXECUTORCH_BUILD_DEVTOOLS +) + check_required_options_on( IF_ON EXECUTORCH_BUILD_EXECUTOR_RUNNER REQUIRES EXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL diff --git a/tools/cmake/preset/llm.cmake b/tools/cmake/preset/llm.cmake index 231a25f0c1e..86a1c0dbe1b 100644 --- a/tools/cmake/preset/llm.cmake +++ b/tools/cmake/preset/llm.cmake @@ -16,13 +16,21 @@ set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TENSOR ON) set_overridable_option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED ON) set_overridable_option(EXECUTORCH_BUILD_XNNPACK ON) -# Turn on the quantized and LLM kernels unless on windows cuda build which -# currently doesn't support this due to using msvc. -if(NOT (EXECUTORCH_BUILD_CUDA AND (CMAKE_SYSTEM_NAME STREQUAL "Windows" - OR CMAKE_SYSTEM_NAME STREQUAL "WIN32")) +# Turn on the quantized and LLM kernels unless on Windows with MSVC build since +# they don't currently compile. +if(NOT ((CMAKE_SYSTEM_NAME STREQUAL "Windows" OR CMAKE_SYSTEM_NAME STREQUAL + "WIN32") AND MSVC) ) set_overridable_option(EXECUTORCH_BUILD_KERNELS_QUANTIZED ON) set_overridable_option(EXECUTORCH_BUILD_KERNELS_LLM ON) +else() + if(NOT EXECUTORCH_BUILD_CUDA) + message( + WARNING + "The llm custom kernels and the quantized kernels will not be built when using MSVC on Windows. " + "If you need them (since you appear to be building for CPU) try building with -T ClangCL" + ) + endif() endif() if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") diff --git a/tools/cmake/preset/pybind.cmake b/tools/cmake/preset/pybind.cmake index c71c10ad01f..699a7c50358 100644 --- a/tools/cmake/preset/pybind.cmake +++ b/tools/cmake/preset/pybind.cmake @@ -22,6 +22,7 @@ set_overridable_option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_MODULE ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP ON) +set_overridable_option(EXECUTORCH_BUILD_WHEEL_DO_NOT_USE ON) # TODO(larryliu0820): Temporarily disable building llm_runner for Windows wheel # due to the issue of tokenizer file path length limitation. @@ -35,6 +36,9 @@ elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") set_overridable_option(EXECUTORCH_BUILD_EXTENSION_TRAINING ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64|i.86)$") + set_overridable_option(EXECUTORCH_BUILD_QNN OFF) + endif() elseif(CMAKE_SYSTEM_NAME STREQUAL "Windows" OR CMAKE_SYSTEM_NAME STREQUAL "WIN32" ) diff --git a/torch_pin.py b/torch_pin.py index 5e54c848d13..ffdde14bdb5 100644 --- a/torch_pin.py +++ b/torch_pin.py @@ -1,2 +1,2 @@ TORCH_VERSION = "2.10.0" -NIGHTLY_VERSION = "dev20251015" +NIGHTLY_VERSION = "dev20251025" diff --git a/util/python_profiler.py b/util/python_profiler.py index c62b0ffafe0..33395afaedf 100644 --- a/util/python_profiler.py +++ b/util/python_profiler.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -12,14 +13,17 @@ import re from pstats import Stats -from snakeviz.stats import json_stats, table_rows # type: ignore[import-not-found] +from snakeviz.stats import ( # type: ignore[import-not-found,import-untyped] + json_stats, + table_rows, +) from tornado import template # type: ignore[import-not-found] module_found = True snakeviz_templates_dir: str = "" try: - import snakeviz # type: ignore[import-not-found] + import snakeviz # type: ignore[import-not-found,import-untyped] snakeviz_dir = os.path.dirname(os.path.abspath(snakeviz.__file__)) snakeviz_templates_dir = os.path.join(snakeviz_dir, "templates") diff --git a/website/index.html b/website/index.html index 891c357f296..ace59ed8b5e 100644 --- a/website/index.html +++ b/website/index.html @@ -25,22 +25,25 @@ @@ -54,7 +57,7 @@ -
+

Deploy PyTorch models directly to edge devices. Text, vision, and audio AI with privacy-preserving, real-time inference — no cloud required.

@@ -661,7 +664,7 @@

High

Run complex multimodal LLMs with simplified C++ interfaces

-
+

Multimodal Runner - Text + Vision + Audio in One API diff --git a/website/style.css b/website/style.css index 9c60bd3734e..5e1e1e53baa 100644 --- a/website/style.css +++ b/website/style.css @@ -92,12 +92,12 @@ nav { display: flex; justify-content: space-between; align-items: center; - padding: 1rem 0; + padding: 0.75rem 0; position: relative; z-index: 1; } -.logo { +.nav-logo { display: flex; align-items: center; gap: 0.75rem; @@ -106,11 +106,15 @@ nav { color: var(--text-dark); } -.logo img { +.nav-logo img { height: 40px; filter: drop-shadow(0 0 2px var(--bg-gray)); } +.nav-logo span { + padding-right: 1rem; +} + .nav-links { display: flex; gap: 2rem; @@ -132,6 +136,10 @@ nav { border-bottom: 2px solid var(--primary); } +.nav-overview { + display: none; +} + /* nav search */ .nav-search { margin-left: 2rem; @@ -181,7 +189,7 @@ section.alt { .title_banner-container { position: absolute; - left:150px; + left: 150px; bottom: 60px; display: flex; gap: 12px; @@ -199,7 +207,7 @@ section.alt { font-weight: 800; display: flex; align-items: center; - gap: 12px; + gap: 8px; } /* Hero */ @@ -1014,36 +1022,61 @@ footer a { } /* Responsive */ -@media (max-width: 900px) { +@media (max-width: 1024px) { .container { padding: 0 1rem; } - .nav-search { - display: none; - } + .grid-2x2 { grid-template-columns: 1fr; } - .title_banner { - height: 300px; - } - .title_banner-container { - left: 32%; - transform: translateX(-50%); - padding: 1rem; + .nav-toggle { + display: block; + margin-left: auto; + flex-shrink: 0; } - .logo-text-container { - font-size: 3rem; + .nav-links { + display: none; } - .title_banner-logo { + .nav-links { display: none; + flex-direction: column; + background: rgba(48,48,48,0.98); + position: fixed; + top: 62px; + left: 0; + width: 100%; + box-shadow: 0 4px 12px rgba(0,0,0,0.5); + padding: 0; + gap: 0; + z-index: 9999; + max-height: calc(100vh - 62px); + overflow-y: auto; + } + .nav-links li { + width: 100%; + border-bottom: 1px solid rgba(255,255,255,0.1); + } + .nav-links li a { + padding: 1rem 1.5rem; + min-height: 48px; + display: flex; + align-items: center; + width: 100%; + border-bottom: none; + } + .nav-links li:last-child { + border-bottom: none; + } + .nav-links.open { + display: flex; + } + .nav-logo { + gap: 0.5rem; } } @media (max-width: 768px) { - .container { - padding: 0 1rem; - } section { padding: 3rem 0; } @@ -1104,25 +1137,6 @@ footer a { .flow-arrow { transform: rotate(90deg); } - .nav-links { - display: none; - } - .nav-search { - display: none; - } - .title_banner { - height: 250px; - } - .title_banner-container { - left: 32%; - transform: translateX(-50%); - } - .logo-text-container { - font-size: 2.5rem; - } - .title_banner-logo { - display: none; - } .card-text { font-size: 1rem; } @@ -1139,6 +1153,13 @@ footer a { code { font-size: 0.8rem !important; } + .logo-text-container { + font-size: 3rem; + position: absolute; + left: -100%; + bottom: 50%; + gap: 6px; + } } @media (max-width: 700px) { @@ -1150,22 +1171,9 @@ footer a { } .nav-content { flex-wrap: nowrap; - gap: 0.5rem; - padding: 0.75rem 0; + padding: 0.25rem 0; overflow: visible; } - .logo { - font-size: 1.2rem; - gap: 0.5rem; - flex-shrink: 1; - min-width: 0; - } - .logo span { - display: none; - } - .logo img { - height: 32px; - } .nav-links { display: none; flex-direction: column; @@ -1199,20 +1207,38 @@ footer a { .nav-links.open { display: flex; } - .nav-toggle { - display: block; - margin-left: auto; - flex-shrink: 0; + .nav-logo { + font-size: 1.2rem; + gap: 0.25rem; + flex-shrink: 1; } - .nav-search { - display: none; + .nav-logo img { + height: 32px; } .logo-text-container { - font-size: 2rem; - gap: 8px; + position: absolute; + left: -150%; + bottom: 60%; + font-size: 2.75rem; + gap: 4px; + } +} + +@media (max-width: 650px) { + .title_banner { + height: 300px; } .title_banner-logo { - height: 40px; + height: 68px; + } + .logo-text-container { + display: flex; + align-items: center; + position: absolute; + left: -175%; + bottom: 25%; + font-size: 2.5rem; + font-weight: 800; } } @@ -1232,14 +1258,21 @@ footer a { font-size: 1.2rem; padding: 1rem; } - .title_banner { - height: 200px; + .nav-overview { + display: block; } - .logo-text-container { - font-size: 1.5rem; + .nav-logo { + gap: 0rem; + } + /* Banner and logo sizing */ + .title_banner { + height: 275px; } .title_banner-logo { - display: none; + height: 64px; + } + .logo-text-container { + font-size: 2.25rem; } .btn { padding: 0.75rem 1.5rem; @@ -1275,3 +1308,9 @@ footer a { padding: 0 0.5rem; } } + +@media (max-width: 400px) { + .nav-logo span { + display: none; + } +}