diff --git a/CLAUDE.md b/CLAUDE.md index ad69ff5..a7007ff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -192,14 +192,15 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?; ``` **DSL Features:** -- Block/grid indices: `block_idx_x()`, `thread_idx_x()`, `block_dim_x()`, `grid_dim_x()`, etc. +- Block/grid indices: `block_idx_x()`, `thread_idx_x()`, `block_dim_x()`, `grid_dim_x()`, `warp_size()`, etc. - Control flow: `if/else`, `match` → switch/case, early `return` - Loops: `for i in 0..n`, `while cond`, `loop` with `break`/`continue` -- Stencil intrinsics: `pos.north(buf)`, `pos.south(buf)`, `pos.east(buf)`, `pos.west(buf)`, `pos.at(buf, dx, dy)` +- Stencil intrinsics (2D): `pos.north(buf)`, `pos.south(buf)`, `pos.east(buf)`, `pos.west(buf)`, `pos.at(buf, dx, dy)` +- Stencil intrinsics (3D): `pos.up(buf)`, `pos.down(buf)`, `pos.at(buf, dx, dy, dz)` for volumetric kernels - Shared memory: `__shared__` arrays and tiles with `SharedMemoryConfig` - Struct literals: `Point { x: 1.0, y: 2.0 }` → C compound literals - Reference expressions: `&arr[idx]` → pointer to element with automatic `->` operator for field access -- 45+ GPU intrinsics (atomics, warp ops, sync, math) +- 120+ GPU intrinsics across 13 categories (synchronization, atomics, math, trig, hyperbolic, exponential, classification, warp, bit manipulation, memory, special, index, timing) **Ring Kernel Features:** - Persistent message loop with ControlBlock lifecycle management @@ -282,7 +283,7 @@ Main crate (`ringkernel`) features: - ringkernel-core: 65 tests - ringkernel-cpu: 11 tests - ringkernel-cuda: 6 GPU execution tests -- ringkernel-cuda-codegen: 143 tests (loops, shared memory, ring kernels, K2K, reference expressions) +- ringkernel-cuda-codegen: 171 tests (loops, shared memory, ring kernels, K2K, reference expressions, 120+ GPU intrinsics) - ringkernel-wgpu-codegen: 50 tests (types, intrinsics, transpiler, validation) - ringkernel-derive: 14 macro tests - ringkernel-wavesim: 49 tests (including educational modes) diff --git a/crates/ringkernel-cuda-codegen/README.md b/crates/ringkernel-cuda-codegen/README.md index a52bec2..09a46e7 100644 --- a/crates/ringkernel-cuda-codegen/README.md +++ b/crates/ringkernel-cuda-codegen/README.md @@ -7,7 +7,7 @@ Rust-to-CUDA transpiler for RingKernel GPU kernels. This crate enables writing GPU kernels in a restricted Rust DSL and transpiling them to CUDA C code. It supports three kernel types: 1. **Global Kernels** - Standard CUDA `__global__` functions -2. **Stencil Kernels** - Tile-based kernels with `GridPos` abstraction +2. **Stencil Kernels** - Tile-based kernels with `GridPos` abstraction (2D and 3D) 3. **Ring Kernels** - Persistent actor kernels with message loops ## Installation @@ -39,11 +39,12 @@ let cuda_code = transpile_global_kernel(&func)?; ## Stencil Kernels -For grid-based computations with neighbor access: +For grid-based computations with neighbor access (2D and 3D): ```rust -use ringkernel_cuda_codegen::{transpile_stencil_kernel, StencilConfig}; +use ringkernel_cuda_codegen::{transpile_stencil_kernel, StencilConfig, Grid}; +// 2D stencil let func: syn::ItemFn = parse_quote! { fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) { let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) @@ -53,10 +54,25 @@ let func: syn::ItemFn = parse_quote! { }; let config = StencilConfig::new("fdtd") + .with_grid(Grid::Grid2D) .with_tile_size(16, 16) .with_halo(1); let cuda_code = transpile_stencil_kernel(&func, &config)?; + +// 3D stencil with up/down neighbors +let func_3d: syn::ItemFn = parse_quote! { + fn laplacian_3d(p: &[f32], out: &mut [f32], pos: GridPos) { + let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) + + pos.up(p) + pos.down(p) - 6.0 * p[pos.idx()]; + out[pos.idx()] = lap; + } +}; + +let config_3d = StencilConfig::new("laplacian") + .with_grid(Grid::Grid3D) + .with_tile_size(8, 8) + .with_halo(1); ``` ## Ring Kernels @@ -86,35 +102,149 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?; ## DSL Reference ### Thread/Block Indices -- `thread_idx_x()`, `thread_idx_y()`, `thread_idx_z()` -- `block_idx_x()`, `block_idx_y()`, `block_idx_z()` -- `block_dim_x()`, `block_dim_y()`, `block_dim_z()` -- `grid_dim_x()`, `grid_dim_y()`, `grid_dim_z()` +- `thread_idx_x()`, `thread_idx_y()`, `thread_idx_z()` → `threadIdx.x/y/z` +- `block_idx_x()`, `block_idx_y()`, `block_idx_z()` → `blockIdx.x/y/z` +- `block_dim_x()`, `block_dim_y()`, `block_dim_z()` → `blockDim.x/y/z` +- `grid_dim_x()`, `grid_dim_y()`, `grid_dim_z()` → `gridDim.x/y/z` +- `warp_size()` → `warpSize` -### Stencil Intrinsics +### Stencil Intrinsics (2D) - `pos.idx()` - Linear index -- `pos.north(buf)`, `pos.south(buf)`, `pos.east(buf)`, `pos.west(buf)` +- `pos.north(buf)`, `pos.south(buf)` - Y-axis neighbors +- `pos.east(buf)`, `pos.west(buf)` - X-axis neighbors - `pos.at(buf, dx, dy)` - Relative offset access -### Synchronization -- `sync_threads()` - Block-level barrier -- `thread_fence()` - Device memory fence -- `thread_fence_block()` - Block memory fence - -### Atomics -- `atomic_add(ptr, val)`, `atomic_sub(ptr, val)` -- `atomic_min(ptr, val)`, `atomic_max(ptr, val)` -- `atomic_exchange(ptr, val)`, `atomic_cas(ptr, compare, val)` +### Stencil Intrinsics (3D) +- `pos.up(buf)`, `pos.down(buf)` - Z-axis neighbors +- `pos.at(buf, dx, dy, dz)` - 3D relative offset access -### Math Functions -- `sqrt()`, `abs()`, `floor()`, `ceil()`, `round()` -- `sin()`, `cos()`, `tan()`, `exp()`, `log()` -- `powf()`, `min()`, `max()`, `mul_add()` +### Synchronization +- `sync_threads()` → `__syncthreads()` - Block-level barrier +- `sync_threads_count(pred)` → `__syncthreads_count()` - Count threads with predicate +- `sync_threads_and(pred)` → `__syncthreads_and()` - AND of predicate +- `sync_threads_or(pred)` → `__syncthreads_or()` - OR of predicate +- `thread_fence()` → `__threadfence()` - Device memory fence +- `thread_fence_block()` → `__threadfence_block()` - Block memory fence +- `thread_fence_system()` → `__threadfence_system()` - System memory fence + +### Atomic Operations (Integer) +- `atomic_add(ptr, val)` → `atomicAdd` +- `atomic_sub(ptr, val)` → `atomicSub` +- `atomic_min(ptr, val)` → `atomicMin` +- `atomic_max(ptr, val)` → `atomicMax` +- `atomic_exchange(ptr, val)` → `atomicExch` +- `atomic_cas(ptr, compare, val)` → `atomicCAS` +- `atomic_and(ptr, val)` → `atomicAnd` +- `atomic_or(ptr, val)` → `atomicOr` +- `atomic_xor(ptr, val)` → `atomicXor` +- `atomic_inc(ptr, val)` → `atomicInc` (increment with wrap) +- `atomic_dec(ptr, val)` → `atomicDec` (decrement with wrap) + +### Basic Math Functions +- `sqrt()`, `rsqrt()` - Square root, reciprocal sqrt +- `abs()`, `fabs()` - Absolute value +- `floor()`, `ceil()`, `round()`, `trunc()` - Rounding +- `fma()`, `mul_add()` - Fused multiply-add +- `fmin()`, `fmax()` - Minimum, maximum +- `fmod()`, `remainder()` - Modulo operations +- `copysign()` - Copy sign +- `cbrt()` - Cube root +- `hypot()` - Hypotenuse + +### Trigonometric Functions +- `sin()`, `cos()`, `tan()` - Basic trig +- `asin()`, `acos()`, `atan()`, `atan2()` - Inverse trig +- `sincos()` - Combined sine and cosine +- `sinpi()`, `cospi()` - Sin/cos of π*x + +### Hyperbolic Functions +- `sinh()`, `cosh()`, `tanh()` - Hyperbolic +- `asinh()`, `acosh()`, `atanh()` - Inverse hyperbolic + +### Exponential and Logarithmic Functions +- `exp()`, `exp2()`, `exp10()`, `expm1()` - Exponentials +- `log()`, `ln()`, `log2()`, `log10()`, `log1p()` - Logarithms +- `pow()`, `powf()`, `powi()` - Power +- `ldexp()`, `scalbn()` - Load/scale exponent +- `ilogb()` - Extract exponent +- `erf()`, `erfc()`, `erfinv()`, `erfcinv()` - Error functions +- `lgamma()`, `tgamma()` - Gamma functions + +### Classification Functions +- `is_nan()`, `isnan()` → `isnan` +- `is_infinite()`, `isinf()` → `isinf` +- `is_finite()`, `isfinite()` → `isfinite` +- `is_normal()`, `isnormal()` → `isnormal` +- `signbit()` - Check sign bit +- `nextafter()` - Next representable value +- `fdim()` - Positive difference ### Warp Operations -- `warp_shuffle(val, lane)`, `warp_shuffle_up(val, delta)` -- `warp_shuffle_down(val, delta)`, `warp_shuffle_xor(val, mask)` -- `warp_ballot(pred)`, `warp_all(pred)`, `warp_any(pred)` +- `warp_active_mask()` → `__activemask()` - Active lane mask +- `warp_shfl(mask, val, lane)` → `__shfl_sync` - Shuffle +- `warp_shfl_up(mask, val, delta)` → `__shfl_up_sync` +- `warp_shfl_down(mask, val, delta)` → `__shfl_down_sync` +- `warp_shfl_xor(mask, val, lane_mask)` → `__shfl_xor_sync` +- `warp_ballot(mask, pred)` → `__ballot_sync` +- `warp_all(mask, pred)` → `__all_sync` +- `warp_any(mask, pred)` → `__any_sync` + +### Warp Match Operations (Volta+) +- `warp_match_any(mask, val)` → `__match_any_sync` +- `warp_match_all(mask, val)` → `__match_all_sync` + +### Warp Reduce Operations (SM 8.0+) +- `warp_reduce_add(mask, val)` → `__reduce_add_sync` +- `warp_reduce_min(mask, val)` → `__reduce_min_sync` +- `warp_reduce_max(mask, val)` → `__reduce_max_sync` +- `warp_reduce_and(mask, val)` → `__reduce_and_sync` +- `warp_reduce_or(mask, val)` → `__reduce_or_sync` +- `warp_reduce_xor(mask, val)` → `__reduce_xor_sync` + +### Bit Manipulation +- `popc()`, `popcount()`, `count_ones()` → `__popc` - Population count +- `clz()`, `leading_zeros()` → `__clz` - Count leading zeros +- `ctz()`, `trailing_zeros()` → `__ffs - 1` - Count trailing zeros +- `ffs()` → `__ffs` - Find first set +- `brev()`, `reverse_bits()` → `__brev` - Bit reverse +- `byte_perm()` → `__byte_perm` - Byte permutation +- `funnel_shift_left()` → `__funnelshift_l` +- `funnel_shift_right()` → `__funnelshift_r` + +### Memory Operations +- `ldg(ptr)`, `load_global(ptr)` → `__ldg` - Read-only cache load +- `prefetch_l1(ptr)` → `__prefetch_l1` - L1 prefetch +- `prefetch_l2(ptr)` → `__prefetch_l2` - L2 prefetch + +### Special Functions +- `rcp()`, `recip()` → `__frcp_rn` - Fast reciprocal +- `fast_div()` → `__fdividef` - Fast division +- `saturate()`, `clamp_01()` → `__saturatef` - Saturate to [0,1] +- `j0()`, `j1()`, `jn()` - Bessel functions of first kind +- `y0()`, `y1()`, `yn()` - Bessel functions of second kind +- `normcdf()`, `normcdfinv()` - Normal CDF +- `cyl_bessel_i0()`, `cyl_bessel_i1()` - Cylindrical Bessel functions + +### Clock and Timing +- `clock()` → `clock()` - 32-bit clock counter +- `clock64()` → `clock64()` - 64-bit clock counter +- `nanosleep(ns)` → `__nanosleep` - Sleep for nanoseconds + +### RingContext Methods +- `ctx.thread_id()` → `threadIdx.x` +- `ctx.block_id()` → `blockIdx.x` +- `ctx.global_thread_id()` → `(blockIdx.x * blockDim.x + threadIdx.x)` +- `ctx.sync_threads()` → `__syncthreads()` +- `ctx.lane_id()` → `(threadIdx.x % 32)` +- `ctx.warp_id()` → `(threadIdx.x / 32)` + +### Ring Kernel Intrinsics +- `is_active()`, `should_terminate()`, `mark_terminated()` +- `messages_processed()`, `input_queue_size()`, `output_queue_size()` +- `input_queue_empty()`, `output_queue_empty()`, `enqueue_response(&resp)` +- `hlc_tick()`, `hlc_update(ts)`, `hlc_now()` - HLC operations +- `k2k_send(target, &msg)`, `k2k_try_recv()` - K2K messaging +- `k2k_has_message()`, `k2k_peek()`, `k2k_pending_count()` ## Type Mapping @@ -130,13 +260,33 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?; | `&[T]` | `const T* __restrict__` | | `&mut [T]` | `T* __restrict__` | +## Intrinsic Count + +The transpiler supports **120+ GPU intrinsics** across 13 categories: + +| Category | Count | Examples | +|----------|-------|----------| +| Synchronization | 7 | `sync_threads`, `thread_fence` | +| Atomics | 11 | `atomic_add`, `atomic_cas`, `atomic_and` | +| Math | 16 | `sqrt`, `fma`, `cbrt`, `hypot` | +| Trigonometric | 11 | `sin`, `asin`, `atan2`, `sincos` | +| Hyperbolic | 6 | `sinh`, `asinh` | +| Exponential | 18 | `exp`, `log2`, `erf`, `gamma` | +| Classification | 8 | `isnan`, `isfinite`, `signbit` | +| Warp | 16 | `warp_shfl`, `warp_reduce_add`, `warp_match_any` | +| Bit Manipulation | 8 | `popc`, `clz`, `brev`, `funnel_shift_left` | +| Memory | 3 | `ldg`, `prefetch_l1` | +| Special | 13 | `rcp`, `saturate`, `normcdf` | +| Index | 13 | `thread_idx_x`, `warp_size` | +| Timing | 3 | `clock`, `clock64`, `nanosleep` | + ## Testing ```bash cargo test -p ringkernel-cuda-codegen ``` -The crate includes 143 tests covering all kernel types and language features. +The crate includes 171 tests covering all kernel types, intrinsics, and language features. ## License diff --git a/crates/ringkernel-cuda-codegen/src/dsl.rs b/crates/ringkernel-cuda-codegen/src/dsl.rs index 11ae806..e966674 100644 --- a/crates/ringkernel-cuda-codegen/src/dsl.rs +++ b/crates/ringkernel-cuda-codegen/src/dsl.rs @@ -21,13 +21,44 @@ //! ```ignore //! sync_threads(); // -> __syncthreads() //! ``` +//! +//! # Math Functions +//! +//! All standard math functions are available with CPU fallbacks: +//! - Trigonometric: sin, cos, tan, asin, acos, atan, atan2 +//! - Hyperbolic: sinh, cosh, tanh, asinh, acosh, atanh +//! - Exponential: exp, exp2, exp10, expm1, log, log2, log10, log1p +//! - Power: pow, sqrt, rsqrt, cbrt +//! - Rounding: floor, ceil, round, trunc +//! - Comparison: fmin, fmax, fdim, copysign +//! +//! # Warp Operations +//! +//! ```ignore +//! let mask = warp_active_mask(); // Get active lane mask +//! let result = warp_reduce_add(mask, value); // Warp-level sum +//! let shuffled = warp_shfl(mask, value, lane); // Shuffle +//! ``` +//! +//! # Bit Manipulation +//! +//! ```ignore +//! let bits = popc(x); // Count set bits +//! let zeros = clz(x); // Count leading zeros +//! let rev = brev(x); // Reverse bits +//! ``` + +use std::sync::atomic::{fence, Ordering}; + +// ============================================================================ +// Thread/Block Index Functions +// ============================================================================ /// Get the thread index within a block (x dimension). /// Transpiles to: `threadIdx.x` #[inline] pub fn thread_idx_x() -> i32 { - // CPU fallback: single-threaded execution uses index 0 - 0 + 0 // CPU fallback: single-threaded execution } /// Get the thread index within a block (y dimension). @@ -107,6 +138,17 @@ pub fn grid_dim_z() -> i32 { 1 } +/// Get the warp size (always 32 on NVIDIA GPUs). +/// Transpiles to: `warpSize` +#[inline] +pub fn warp_size() -> i32 { + 32 +} + +// ============================================================================ +// Synchronization Functions +// ============================================================================ + /// Synchronize all threads in a block. /// Transpiles to: `__syncthreads()` #[inline] @@ -114,18 +156,790 @@ pub fn sync_threads() { // CPU fallback: no-op (single-threaded) } +/// Synchronize threads and count predicate. +/// Transpiles to: `__syncthreads_count(predicate)` +#[inline] +pub fn sync_threads_count(predicate: bool) -> i32 { + if predicate { 1 } else { 0 } +} + +/// Synchronize threads with AND of predicate. +/// Transpiles to: `__syncthreads_and(predicate)` +#[inline] +pub fn sync_threads_and(predicate: bool) -> i32 { + if predicate { 1 } else { 0 } +} + +/// Synchronize threads with OR of predicate. +/// Transpiles to: `__syncthreads_or(predicate)` +#[inline] +pub fn sync_threads_or(predicate: bool) -> i32 { + if predicate { 1 } else { 0 } +} + /// Thread memory fence. /// Transpiles to: `__threadfence()` #[inline] pub fn thread_fence() { - std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst); + fence(Ordering::SeqCst); } /// Block-level memory fence. /// Transpiles to: `__threadfence_block()` #[inline] pub fn thread_fence_block() { - std::sync::atomic::fence(std::sync::atomic::Ordering::Release); + fence(Ordering::Release); +} + +/// System-wide memory fence. +/// Transpiles to: `__threadfence_system()` +#[inline] +pub fn thread_fence_system() { + fence(Ordering::SeqCst); +} + +// ============================================================================ +// Atomic Operations (CPU fallbacks - not thread-safe!) +// ============================================================================ + +/// Atomic add. Transpiles to: `atomicAdd(addr, val)` +/// WARNING: CPU fallback is NOT thread-safe! +#[inline] +pub fn atomic_add(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr += val; + old +} + +/// Atomic add for f32. Transpiles to: `atomicAdd(addr, val)` +#[inline] +pub fn atomic_add_f32(addr: &mut f32, val: f32) -> f32 { + let old = *addr; + *addr += val; + old +} + +/// Atomic subtract. Transpiles to: `atomicSub(addr, val)` +#[inline] +pub fn atomic_sub(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr -= val; + old +} + +/// Atomic minimum. Transpiles to: `atomicMin(addr, val)` +#[inline] +pub fn atomic_min(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr = old.min(val); + old +} + +/// Atomic maximum. Transpiles to: `atomicMax(addr, val)` +#[inline] +pub fn atomic_max(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr = old.max(val); + old +} + +/// Atomic exchange. Transpiles to: `atomicExch(addr, val)` +#[inline] +pub fn atomic_exchange(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr = val; + old +} + +/// Atomic compare and swap. Transpiles to: `atomicCAS(addr, compare, val)` +#[inline] +pub fn atomic_cas(addr: &mut i32, compare: i32, val: i32) -> i32 { + let old = *addr; + if old == compare { + *addr = val; + } + old +} + +/// Atomic AND. Transpiles to: `atomicAnd(addr, val)` +#[inline] +pub fn atomic_and(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr &= val; + old +} + +/// Atomic OR. Transpiles to: `atomicOr(addr, val)` +#[inline] +pub fn atomic_or(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr |= val; + old +} + +/// Atomic XOR. Transpiles to: `atomicXor(addr, val)` +#[inline] +pub fn atomic_xor(addr: &mut i32, val: i32) -> i32 { + let old = *addr; + *addr ^= val; + old +} + +/// Atomic increment with wrap. Transpiles to: `atomicInc(addr, val)` +#[inline] +pub fn atomic_inc(addr: &mut u32, val: u32) -> u32 { + let old = *addr; + *addr = if old >= val { 0 } else { old + 1 }; + old +} + +/// Atomic decrement with wrap. Transpiles to: `atomicDec(addr, val)` +#[inline] +pub fn atomic_dec(addr: &mut u32, val: u32) -> u32 { + let old = *addr; + *addr = if old == 0 || old > val { val } else { old - 1 }; + old +} + +// ============================================================================ +// Basic Math Functions +// ============================================================================ + +/// Square root. Transpiles to: `sqrtf(x)` +#[inline] +pub fn sqrt(x: f32) -> f32 { + x.sqrt() +} + +/// Reciprocal square root. Transpiles to: `rsqrtf(x)` +#[inline] +pub fn rsqrt(x: f32) -> f32 { + 1.0 / x.sqrt() +} + +/// Absolute value for f32. Transpiles to: `fabsf(x)` +#[inline] +pub fn fabs(x: f32) -> f32 { + x.abs() +} + +/// Floor. Transpiles to: `floorf(x)` +#[inline] +pub fn floor(x: f32) -> f32 { + x.floor() +} + +/// Ceiling. Transpiles to: `ceilf(x)` +#[inline] +pub fn ceil(x: f32) -> f32 { + x.ceil() +} + +/// Round to nearest. Transpiles to: `roundf(x)` +#[inline] +pub fn round(x: f32) -> f32 { + x.round() +} + +/// Truncate toward zero. Transpiles to: `truncf(x)` +#[inline] +pub fn trunc(x: f32) -> f32 { + x.trunc() +} + +/// Fused multiply-add. Transpiles to: `fmaf(a, b, c)` +#[inline] +pub fn fma(a: f32, b: f32, c: f32) -> f32 { + a.mul_add(b, c) +} + +/// Minimum. Transpiles to: `fminf(a, b)` +#[inline] +pub fn fmin(a: f32, b: f32) -> f32 { + a.min(b) +} + +/// Maximum. Transpiles to: `fmaxf(a, b)` +#[inline] +pub fn fmax(a: f32, b: f32) -> f32 { + a.max(b) +} + +/// Floating-point modulo. Transpiles to: `fmodf(x, y)` +#[inline] +pub fn fmod(x: f32, y: f32) -> f32 { + x % y +} + +/// Remainder. Transpiles to: `remainderf(x, y)` +#[inline] +pub fn remainder(x: f32, y: f32) -> f32 { + x - (x / y).round() * y +} + +/// Copy sign. Transpiles to: `copysignf(x, y)` +#[inline] +pub fn copysign(x: f32, y: f32) -> f32 { + x.copysign(y) +} + +/// Cube root. Transpiles to: `cbrtf(x)` +#[inline] +pub fn cbrt(x: f32) -> f32 { + x.cbrt() +} + +/// Hypotenuse. Transpiles to: `hypotf(x, y)` +#[inline] +pub fn hypot(x: f32, y: f32) -> f32 { + x.hypot(y) +} + +// ============================================================================ +// Trigonometric Functions +// ============================================================================ + +/// Sine. Transpiles to: `sinf(x)` +#[inline] +pub fn sin(x: f32) -> f32 { + x.sin() +} + +/// Cosine. Transpiles to: `cosf(x)` +#[inline] +pub fn cos(x: f32) -> f32 { + x.cos() +} + +/// Tangent. Transpiles to: `tanf(x)` +#[inline] +pub fn tan(x: f32) -> f32 { + x.tan() +} + +/// Arcsine. Transpiles to: `asinf(x)` +#[inline] +pub fn asin(x: f32) -> f32 { + x.asin() +} + +/// Arccosine. Transpiles to: `acosf(x)` +#[inline] +pub fn acos(x: f32) -> f32 { + x.acos() +} + +/// Arctangent. Transpiles to: `atanf(x)` +#[inline] +pub fn atan(x: f32) -> f32 { + x.atan() +} + +/// Two-argument arctangent. Transpiles to: `atan2f(y, x)` +#[inline] +pub fn atan2(y: f32, x: f32) -> f32 { + y.atan2(x) +} + +/// Sine and cosine together. Transpiles to: `sincosf(x, &s, &c)` +#[inline] +pub fn sincos(x: f32) -> (f32, f32) { + (x.sin(), x.cos()) +} + +/// Sine of pi*x. Transpiles to: `sinpif(x)` +#[inline] +pub fn sinpi(x: f32) -> f32 { + (x * std::f32::consts::PI).sin() +} + +/// Cosine of pi*x. Transpiles to: `cospif(x)` +#[inline] +pub fn cospi(x: f32) -> f32 { + (x * std::f32::consts::PI).cos() +} + +// ============================================================================ +// Hyperbolic Functions +// ============================================================================ + +/// Hyperbolic sine. Transpiles to: `sinhf(x)` +#[inline] +pub fn sinh(x: f32) -> f32 { + x.sinh() +} + +/// Hyperbolic cosine. Transpiles to: `coshf(x)` +#[inline] +pub fn cosh(x: f32) -> f32 { + x.cosh() +} + +/// Hyperbolic tangent. Transpiles to: `tanhf(x)` +#[inline] +pub fn tanh(x: f32) -> f32 { + x.tanh() +} + +/// Inverse hyperbolic sine. Transpiles to: `asinhf(x)` +#[inline] +pub fn asinh(x: f32) -> f32 { + x.asinh() +} + +/// Inverse hyperbolic cosine. Transpiles to: `acoshf(x)` +#[inline] +pub fn acosh(x: f32) -> f32 { + x.acosh() +} + +/// Inverse hyperbolic tangent. Transpiles to: `atanhf(x)` +#[inline] +pub fn atanh(x: f32) -> f32 { + x.atanh() +} + +// ============================================================================ +// Exponential and Logarithmic Functions +// ============================================================================ + +/// Exponential (base e). Transpiles to: `expf(x)` +#[inline] +pub fn exp(x: f32) -> f32 { + x.exp() +} + +/// Exponential (base 2). Transpiles to: `exp2f(x)` +#[inline] +pub fn exp2(x: f32) -> f32 { + x.exp2() +} + +/// Exponential (base 10). Transpiles to: `exp10f(x)` +#[inline] +pub fn exp10(x: f32) -> f32 { + (x * std::f32::consts::LN_10).exp() +} + +/// exp(x) - 1 (accurate for small x). Transpiles to: `expm1f(x)` +#[inline] +pub fn expm1(x: f32) -> f32 { + x.exp_m1() +} + +/// Natural logarithm (base e). Transpiles to: `logf(x)` +#[inline] +pub fn log(x: f32) -> f32 { + x.ln() +} + +/// Logarithm (base 2). Transpiles to: `log2f(x)` +#[inline] +pub fn log2(x: f32) -> f32 { + x.log2() +} + +/// Logarithm (base 10). Transpiles to: `log10f(x)` +#[inline] +pub fn log10(x: f32) -> f32 { + x.log10() +} + +/// log(1 + x) (accurate for small x). Transpiles to: `log1pf(x)` +#[inline] +pub fn log1p(x: f32) -> f32 { + x.ln_1p() +} + +/// Power. Transpiles to: `powf(x, y)` +#[inline] +pub fn pow(x: f32, y: f32) -> f32 { + x.powf(y) +} + +/// Load exponent. Transpiles to: `ldexpf(x, exp)` +#[inline] +pub fn ldexp(x: f32, exp: i32) -> f32 { + x * 2.0_f32.powi(exp) +} + +/// Scale by power of 2. Transpiles to: `scalbnf(x, n)` +#[inline] +pub fn scalbn(x: f32, n: i32) -> f32 { + x * 2.0_f32.powi(n) +} + +/// Extract exponent. Transpiles to: `ilogbf(x)` +#[inline] +pub fn ilogb(x: f32) -> i32 { + if x == 0.0 { + i32::MIN + } else if x.is_infinite() { + i32::MAX + } else { + x.abs().log2().floor() as i32 + } +} + +/// Error function. Transpiles to: `erff(x)` +#[inline] +pub fn erf(x: f32) -> f32 { + // Approximation using Horner form + let a1 = 0.254829592_f32; + let a2 = -0.284496736_f32; + let a3 = 1.421413741_f32; + let a4 = -1.453152027_f32; + let a5 = 1.061405429_f32; + let p = 0.3275911_f32; + + let sign = if x < 0.0 { -1.0 } else { 1.0 }; + let x = x.abs(); + let t = 1.0 / (1.0 + p * x); + let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp(); + sign * y +} + +/// Complementary error function. Transpiles to: `erfcf(x)` +#[inline] +pub fn erfc(x: f32) -> f32 { + 1.0 - erf(x) +} + +// ============================================================================ +// Classification and Comparison Functions +// ============================================================================ + +/// Check if NaN. Transpiles to: `isnan(x)` +#[inline] +pub fn is_nan(x: f32) -> bool { + x.is_nan() +} + +/// Check if infinite. Transpiles to: `isinf(x)` +#[inline] +pub fn is_infinite(x: f32) -> bool { + x.is_infinite() +} + +/// Check if finite. Transpiles to: `isfinite(x)` +#[inline] +pub fn is_finite(x: f32) -> bool { + x.is_finite() +} + +/// Check if normal. Transpiles to: `isnormal(x)` +#[inline] +pub fn is_normal(x: f32) -> bool { + x.is_normal() +} + +/// Check sign bit. Transpiles to: `signbit(x)` +#[inline] +pub fn signbit(x: f32) -> bool { + x.is_sign_negative() +} + +/// Next representable value. Transpiles to: `nextafterf(x, y)` +#[inline] +pub fn nextafter(x: f32, y: f32) -> f32 { + if x == y { + y + } else if y > x { + f32::from_bits(x.to_bits() + 1) + } else { + f32::from_bits(x.to_bits() - 1) + } +} + +/// Floating-point difference. Transpiles to: `fdimf(x, y)` +#[inline] +pub fn fdim(x: f32, y: f32) -> f32 { + if x > y { x - y } else { 0.0 } +} + +// ============================================================================ +// Warp-Level Operations +// ============================================================================ + +/// Get active thread mask. Transpiles to: `__activemask()` +#[inline] +pub fn warp_active_mask() -> u32 { + 1 // CPU fallback: only one thread active +} + +/// Warp ballot. Transpiles to: `__ballot_sync(mask, predicate)` +#[inline] +pub fn warp_ballot(_mask: u32, predicate: bool) -> u32 { + if predicate { 1 } else { 0 } +} + +/// Warp all predicate. Transpiles to: `__all_sync(mask, predicate)` +#[inline] +pub fn warp_all(_mask: u32, predicate: bool) -> bool { + predicate +} + +/// Warp any predicate. Transpiles to: `__any_sync(mask, predicate)` +#[inline] +pub fn warp_any(_mask: u32, predicate: bool) -> bool { + predicate +} + +/// Warp shuffle. Transpiles to: `__shfl_sync(mask, val, lane)` +#[inline] +pub fn warp_shfl(_mask: u32, val: T, _lane: i32) -> T { + val // CPU fallback: return same value +} + +/// Warp shuffle up. Transpiles to: `__shfl_up_sync(mask, val, delta)` +#[inline] +pub fn warp_shfl_up(_mask: u32, val: T, _delta: u32) -> T { + val +} + +/// Warp shuffle down. Transpiles to: `__shfl_down_sync(mask, val, delta)` +#[inline] +pub fn warp_shfl_down(_mask: u32, val: T, _delta: u32) -> T { + val +} + +/// Warp shuffle XOR. Transpiles to: `__shfl_xor_sync(mask, val, lane_mask)` +#[inline] +pub fn warp_shfl_xor(_mask: u32, val: T, _lane_mask: i32) -> T { + val +} + +/// Warp reduce add. Transpiles to: `__reduce_add_sync(mask, val)` +#[inline] +pub fn warp_reduce_add(_mask: u32, val: i32) -> i32 { + val // CPU: single thread, no reduction needed +} + +/// Warp reduce min. Transpiles to: `__reduce_min_sync(mask, val)` +#[inline] +pub fn warp_reduce_min(_mask: u32, val: i32) -> i32 { + val +} + +/// Warp reduce max. Transpiles to: `__reduce_max_sync(mask, val)` +#[inline] +pub fn warp_reduce_max(_mask: u32, val: i32) -> i32 { + val +} + +/// Warp reduce AND. Transpiles to: `__reduce_and_sync(mask, val)` +#[inline] +pub fn warp_reduce_and(_mask: u32, val: u32) -> u32 { + val +} + +/// Warp reduce OR. Transpiles to: `__reduce_or_sync(mask, val)` +#[inline] +pub fn warp_reduce_or(_mask: u32, val: u32) -> u32 { + val +} + +/// Warp reduce XOR. Transpiles to: `__reduce_xor_sync(mask, val)` +#[inline] +pub fn warp_reduce_xor(_mask: u32, val: u32) -> u32 { + val +} + +/// Warp match any. Transpiles to: `__match_any_sync(mask, val)` +#[inline] +pub fn warp_match_any(_mask: u32, _val: u32) -> u32 { + 1 // CPU: single thread always matches itself +} + +/// Warp match all. Transpiles to: `__match_all_sync(mask, val, pred)` +#[inline] +pub fn warp_match_all(_mask: u32, _val: u32) -> (u32, bool) { + (1, true) // CPU: single thread, trivially all match +} + +// ============================================================================ +// Bit Manipulation Functions +// ============================================================================ + +/// Population count (count set bits). Transpiles to: `__popc(x)` +#[inline] +pub fn popc(x: u32) -> i32 { + x.count_ones() as i32 +} + +/// Population count (i32 version). +#[inline] +pub fn popcount(x: i32) -> i32 { + (x as u32).count_ones() as i32 +} + +/// Count leading zeros. Transpiles to: `__clz(x)` +#[inline] +pub fn clz(x: u32) -> i32 { + x.leading_zeros() as i32 +} + +/// Count leading zeros (i32 version). +#[inline] +pub fn leading_zeros(x: i32) -> i32 { + (x as u32).leading_zeros() as i32 +} + +/// Count trailing zeros. Transpiles to: `__ffs(x) - 1` +#[inline] +pub fn ctz(x: u32) -> i32 { + if x == 0 { 32 } else { x.trailing_zeros() as i32 } +} + +/// Count trailing zeros (i32 version). +#[inline] +pub fn trailing_zeros(x: i32) -> i32 { + if x == 0 { 32 } else { (x as u32).trailing_zeros() as i32 } +} + +/// Find first set bit (1-indexed, 0 if none). Transpiles to: `__ffs(x)` +#[inline] +pub fn ffs(x: u32) -> i32 { + if x == 0 { 0 } else { (x.trailing_zeros() + 1) as i32 } +} + +/// Bit reverse. Transpiles to: `__brev(x)` +#[inline] +pub fn brev(x: u32) -> u32 { + x.reverse_bits() +} + +/// Bit reverse (i32 version). +#[inline] +pub fn reverse_bits(x: i32) -> i32 { + (x as u32).reverse_bits() as i32 +} + +/// Byte permutation. Transpiles to: `__byte_perm(x, y, s)` +#[inline] +pub fn byte_perm(x: u32, y: u32, s: u32) -> u32 { + let bytes = [ + (x & 0xFF) as u8, + ((x >> 8) & 0xFF) as u8, + ((x >> 16) & 0xFF) as u8, + ((x >> 24) & 0xFF) as u8, + (y & 0xFF) as u8, + ((y >> 8) & 0xFF) as u8, + ((y >> 16) & 0xFF) as u8, + ((y >> 24) & 0xFF) as u8, + ]; + let b0 = bytes[(s & 0x7) as usize] as u32; + let b1 = bytes[((s >> 4) & 0x7) as usize] as u32; + let b2 = bytes[((s >> 8) & 0x7) as usize] as u32; + let b3 = bytes[((s >> 12) & 0x7) as usize] as u32; + b0 | (b1 << 8) | (b2 << 16) | (b3 << 24) +} + +/// Funnel shift left. Transpiles to: `__funnelshift_l(lo, hi, shift)` +#[inline] +pub fn funnel_shift_left(lo: u32, hi: u32, shift: u32) -> u32 { + let shift = shift & 31; + if shift == 0 { + lo + } else { + (hi << shift) | (lo >> (32 - shift)) + } +} + +/// Funnel shift right. Transpiles to: `__funnelshift_r(lo, hi, shift)` +#[inline] +pub fn funnel_shift_right(lo: u32, hi: u32, shift: u32) -> u32 { + let shift = shift & 31; + if shift == 0 { + lo + } else { + (lo >> shift) | (hi << (32 - shift)) + } +} + +// ============================================================================ +// Memory Operations +// ============================================================================ + +/// Read-only cache load. Transpiles to: `__ldg(ptr)` +#[inline] +pub fn ldg(ptr: &T) -> T { + *ptr +} + +/// Load from global memory (alias for ldg). +#[inline] +pub fn load_global(ptr: &T) -> T { + *ptr +} + +/// Prefetch to L1 cache. Transpiles to: `__prefetch_l1(ptr)` +#[inline] +pub fn prefetch_l1(_ptr: &T) { + // CPU fallback: no-op +} + +/// Prefetch to L2 cache. Transpiles to: `__prefetch_l2(ptr)` +#[inline] +pub fn prefetch_l2(_ptr: &T) { + // CPU fallback: no-op +} + +// ============================================================================ +// Special Functions +// ============================================================================ + +/// Fast reciprocal. Transpiles to: `__frcp_rn(x)` +#[inline] +pub fn rcp(x: f32) -> f32 { + 1.0 / x +} + +/// Fast division. Transpiles to: `__fdividef(x, y)` +#[inline] +pub fn fast_div(x: f32, y: f32) -> f32 { + x / y +} + +/// Saturate to [0, 1]. Transpiles to: `__saturatef(x)` +#[inline] +pub fn saturate(x: f32) -> f32 { + x.clamp(0.0, 1.0) +} + +/// Clamp to [0, 1] (alias for saturate). +#[inline] +pub fn clamp_01(x: f32) -> f32 { + saturate(x) +} + +// ============================================================================ +// Clock and Timing +// ============================================================================ + +/// Read clock counter. Transpiles to: `clock()` +#[inline] +pub fn clock() -> u32 { + // CPU fallback: use std time + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos() as u32) + .unwrap_or(0) +} + +/// Read 64-bit clock counter. Transpiles to: `clock64()` +#[inline] +pub fn clock64() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos() as u64) + .unwrap_or(0) +} + +/// Nanosleep. Transpiles to: `__nanosleep(ns)` +#[inline] +pub fn nanosleep(ns: u32) { + std::thread::sleep(std::time::Duration::from_nanos(ns as u64)); } #[cfg(test)] @@ -151,5 +965,110 @@ mod tests { assert_eq!(block_dim_x(), 1); assert_eq!(block_dim_y(), 1); assert_eq!(grid_dim_x(), 1); + assert_eq!(warp_size(), 32); + } + + #[test] + fn test_math_functions() { + assert!((sqrt(4.0) - 2.0).abs() < 1e-6); + assert!((rsqrt(4.0) - 0.5).abs() < 1e-6); + assert!((sin(0.0)).abs() < 1e-6); + assert!((cos(0.0) - 1.0).abs() < 1e-6); + assert!((exp(0.0) - 1.0).abs() < 1e-6); + assert!((log(1.0)).abs() < 1e-6); + } + + #[test] + fn test_trigonometric_functions() { + let pi = std::f32::consts::PI; + assert!((sin(pi / 2.0) - 1.0).abs() < 1e-6); + assert!((cos(pi) + 1.0).abs() < 1e-6); + assert!((tan(0.0)).abs() < 1e-6); + assert!((asin(1.0) - pi / 2.0).abs() < 1e-6); + assert!((atan2(1.0, 1.0) - pi / 4.0).abs() < 1e-6); + } + + #[test] + fn test_hyperbolic_functions() { + assert!((sinh(0.0)).abs() < 1e-6); + assert!((cosh(0.0) - 1.0).abs() < 1e-6); + assert!((tanh(0.0)).abs() < 1e-6); + } + + #[test] + fn test_exponential_functions() { + assert!((exp2(3.0) - 8.0).abs() < 1e-6); + assert!((log2(8.0) - 3.0).abs() < 1e-6); + assert!((log10(100.0) - 2.0).abs() < 1e-6); + assert!((pow(2.0, 3.0) - 8.0).abs() < 1e-6); + } + + #[test] + fn test_classification_functions() { + assert!(is_nan(f32::NAN)); + assert!(!is_nan(1.0)); + assert!(is_infinite(f32::INFINITY)); + assert!(!is_infinite(1.0)); + assert!(is_finite(1.0)); + assert!(!is_finite(f32::INFINITY)); + } + + #[test] + fn test_bit_manipulation() { + assert_eq!(popc(0b1010_1010), 4); + assert_eq!(clz(1u32), 31); + assert_eq!(clz(0x8000_0000u32), 0); + assert_eq!(ctz(0b1000), 3); + assert_eq!(ffs(0b1000), 4); + assert_eq!(brev(1u32), 0x8000_0000); + } + + #[test] + fn test_warp_operations() { + assert_eq!(warp_active_mask(), 1); + assert_eq!(warp_ballot(0xFFFF_FFFF, true), 1); + assert!(warp_all(0xFFFF_FFFF, true)); + assert!(warp_any(0xFFFF_FFFF, true)); + assert_eq!(warp_reduce_add(0xFFFF_FFFF, 5), 5); + } + + #[test] + fn test_special_functions() { + assert!((rcp(2.0) - 0.5).abs() < 1e-6); + assert!((fast_div(10.0, 2.0) - 5.0).abs() < 1e-6); + assert_eq!(saturate(-1.0), 0.0); + assert_eq!(saturate(0.5), 0.5); + assert_eq!(saturate(2.0), 1.0); + } + + #[test] + fn test_atomic_operations() { + let mut val = 10; + assert_eq!(atomic_add(&mut val, 5), 10); + assert_eq!(val, 15); + + let mut val = 10; + assert_eq!(atomic_sub(&mut val, 3), 10); + assert_eq!(val, 7); + + let mut val = 10; + assert_eq!(atomic_cas(&mut val, 10, 20), 10); + assert_eq!(val, 20); + } + + #[test] + fn test_funnel_shift() { + assert_eq!(funnel_shift_left(0xFFFF_0000, 0x0000_FFFF, 16), 0xFFFF_FFFF); + assert_eq!(funnel_shift_right(0xFFFF_0000, 0x0000_FFFF, 16), 0xFFFF_FFFF); + } + + #[test] + fn test_byte_perm() { + let x = 0x04030201u32; + let y = 0x08070605u32; + // Select bytes 0, 1, 2, 3 from x + assert_eq!(byte_perm(x, y, 0x3210), 0x04030201); + // Select bytes 4, 5, 6, 7 from y + assert_eq!(byte_perm(x, y, 0x7654), 0x08070605); } } diff --git a/crates/ringkernel-cuda-codegen/src/intrinsics.rs b/crates/ringkernel-cuda-codegen/src/intrinsics.rs index 5b17b00..3328ddb 100644 --- a/crates/ringkernel-cuda-codegen/src/intrinsics.rs +++ b/crates/ringkernel-cuda-codegen/src/intrinsics.rs @@ -8,78 +8,324 @@ use std::collections::HashMap; /// GPU intrinsic operations. #[derive(Debug, Clone, PartialEq)] pub enum GpuIntrinsic { - /// Thread synchronization. + // === Synchronization === + /// Thread synchronization within a block. SyncThreads, - /// Thread fence (memory ordering). + /// Thread fence (memory ordering across device). ThreadFence, + /// Thread fence within block. ThreadFenceBlock, + /// Thread fence across system. ThreadFenceSystem, + /// Synchronize threads with predicate. + SyncThreadsCount, + /// Synchronize threads with AND predicate. + SyncThreadsAnd, + /// Synchronize threads with OR predicate. + SyncThreadsOr, - /// Atomic operations. + // === Atomic Operations (Integer) === + /// Atomic add. AtomicAdd, + /// Atomic subtract. AtomicSub, + /// Atomic minimum. AtomicMin, + /// Atomic maximum. AtomicMax, + /// Atomic exchange. AtomicExch, + /// Atomic compare-and-swap. AtomicCas, + /// Atomic bitwise AND. + AtomicAnd, + /// Atomic bitwise OR. + AtomicOr, + /// Atomic bitwise XOR. + AtomicXor, + /// Atomic increment (with wrap). + AtomicInc, + /// Atomic decrement (with wrap). + AtomicDec, - /// Math functions. + // === Basic Math Functions === + /// Square root. Sqrt, + /// Reciprocal square root. Rsqrt, + /// Absolute value (integer). Abs, + /// Absolute value (floating point). Fabs, + /// Floor. Floor, + /// Ceiling. Ceil, + /// Round to nearest. Round, + /// Truncate toward zero. + Trunc, + /// Fused multiply-add. + Fma, + /// Minimum. + Min, + /// Maximum. + Max, + /// Floating-point modulo. + Fmod, + /// Remainder. + Remainder, + /// Copy sign. + Copysign, + /// Cube root. + Cbrt, + /// Hypotenuse. + Hypot, + + // === Trigonometric Functions === + /// Sine. Sin, + /// Cosine. Cos, + /// Tangent. Tan, + /// Arcsine. + Asin, + /// Arccosine. + Acos, + /// Arctangent. + Atan, + /// Two-argument arctangent. + Atan2, + /// Sine and cosine (combined). + Sincos, + /// Sine of pi*x. + Sinpi, + /// Cosine of pi*x. + Cospi, + + // === Hyperbolic Functions === + /// Hyperbolic sine. + Sinh, + /// Hyperbolic cosine. + Cosh, + /// Hyperbolic tangent. + Tanh, + /// Inverse hyperbolic sine. + Asinh, + /// Inverse hyperbolic cosine. + Acosh, + /// Inverse hyperbolic tangent. + Atanh, + + // === Exponential and Logarithmic Functions === + /// Exponential (base e). Exp, + /// Exponential (base 2). + Exp2, + /// Exponential (base 10). + Exp10, + /// exp(x) - 1 (accurate for small x). + Expm1, + /// Natural logarithm (base e). Log, + /// Logarithm (base 2). + Log2, + /// Logarithm (base 10). + Log10, + /// log(1 + x) (accurate for small x). + Log1p, + /// Power. Pow, - Fma, - Min, - Max, + /// Load exponent. + Ldexp, + /// Scale by power of 2. + Scalbn, + /// Extract exponent. + Ilogb, + /// Logarithm of gamma function. + Lgamma, + /// Gamma function. + Tgamma, + /// Error function. + Erf, + /// Complementary error function. + Erfc, + /// Inverse error function. + Erfinv, + /// Inverse complementary error function. + Erfcinv, - /// Warp-level operations. + // === Classification and Comparison === + /// Check if NaN. + Isnan, + /// Check if infinite. + Isinf, + /// Check if finite. + Isfinite, + /// Check if normal. + Isnormal, + /// Check sign bit. + Signbit, + /// Next representable value. + Nextafter, + /// Floating-point difference. + Fdim, + /// Not-a-Number. + Nan, + + // === Warp-Level Operations === + /// Warp shuffle. WarpShfl, + /// Warp shuffle up. WarpShflUp, + /// Warp shuffle down. WarpShflDown, + /// Warp shuffle XOR. WarpShflXor, + /// Get active thread mask. WarpActiveMask, + /// Warp ballot. WarpBallot, + /// Warp all predicate. WarpAll, + /// Warp any predicate. WarpAny, + /// Warp match any. + WarpMatchAny, + /// Warp match all. + WarpMatchAll, + /// Warp reduce add. + WarpReduceAdd, + /// Warp reduce min. + WarpReduceMin, + /// Warp reduce max. + WarpReduceMax, + /// Warp reduce AND. + WarpReduceAnd, + /// Warp reduce OR. + WarpReduceOr, + /// Warp reduce XOR. + WarpReduceXor, + + // === Bit Manipulation === + /// Population count (count set bits). + Popc, + /// Count leading zeros. + Clz, + /// Count trailing zeros (via ffs). + Ctz, + /// Find first set bit. + Ffs, + /// Bit reverse. + Brev, + /// Byte permute. + BytePerm, + /// Funnel shift left. + FunnelShiftLeft, + /// Funnel shift right. + FunnelShiftRight, - /// CUDA thread/block indices. + // === Memory Operations === + /// Read-only cache load. + Ldg, + /// Prefetch L1. + PrefetchL1, + /// Prefetch L2. + PrefetchL2, + + // === Special Functions === + /// Reciprocal. + Rcp, + /// Division (fast). + Fdividef, + /// Saturate to [0,1]. + Saturate, + /// Bessel J0. + J0, + /// Bessel J1. + J1, + /// Bessel Jn. + Jn, + /// Bessel Y0. + Y0, + /// Bessel Y1. + Y1, + /// Bessel Yn. + Yn, + /// Normal CDF. + Normcdf, + /// Inverse normal CDF. + Normcdfinv, + /// Cylindrical Bessel I0. + CylBesselI0, + /// Cylindrical Bessel I1. + CylBesselI1, + + // === CUDA Thread/Block Indices === + /// Thread index X. ThreadIdxX, + /// Thread index Y. ThreadIdxY, + /// Thread index Z. ThreadIdxZ, + /// Block index X. BlockIdxX, + /// Block index Y. BlockIdxY, + /// Block index Z. BlockIdxZ, + /// Block dimension X. BlockDimX, + /// Block dimension Y. BlockDimY, + /// Block dimension Z. BlockDimZ, + /// Grid dimension X. GridDimX, + /// Grid dimension Y. GridDimY, + /// Grid dimension Z. GridDimZ, + /// Warp size (always 32). + WarpSize, + + // === Clock and Timing === + /// Read clock counter. + Clock, + /// Read 64-bit clock counter. + Clock64, + /// Nanosleep. + Nanosleep, } impl GpuIntrinsic { /// Convert to CUDA function/intrinsic name. pub fn to_cuda_string(&self) -> &'static str { match self { + // Synchronization GpuIntrinsic::SyncThreads => "__syncthreads()", GpuIntrinsic::ThreadFence => "__threadfence()", GpuIntrinsic::ThreadFenceBlock => "__threadfence_block()", GpuIntrinsic::ThreadFenceSystem => "__threadfence_system()", + GpuIntrinsic::SyncThreadsCount => "__syncthreads_count", + GpuIntrinsic::SyncThreadsAnd => "__syncthreads_and", + GpuIntrinsic::SyncThreadsOr => "__syncthreads_or", + + // Atomic operations GpuIntrinsic::AtomicAdd => "atomicAdd", GpuIntrinsic::AtomicSub => "atomicSub", GpuIntrinsic::AtomicMin => "atomicMin", GpuIntrinsic::AtomicMax => "atomicMax", GpuIntrinsic::AtomicExch => "atomicExch", GpuIntrinsic::AtomicCas => "atomicCAS", + GpuIntrinsic::AtomicAnd => "atomicAnd", + GpuIntrinsic::AtomicOr => "atomicOr", + GpuIntrinsic::AtomicXor => "atomicXor", + GpuIntrinsic::AtomicInc => "atomicInc", + GpuIntrinsic::AtomicDec => "atomicDec", + + // Basic math GpuIntrinsic::Sqrt => "sqrtf", GpuIntrinsic::Rsqrt => "rsqrtf", GpuIntrinsic::Abs => "abs", @@ -87,15 +333,67 @@ impl GpuIntrinsic { GpuIntrinsic::Floor => "floorf", GpuIntrinsic::Ceil => "ceilf", GpuIntrinsic::Round => "roundf", + GpuIntrinsic::Trunc => "truncf", + GpuIntrinsic::Fma => "fmaf", + GpuIntrinsic::Min => "fminf", + GpuIntrinsic::Max => "fmaxf", + GpuIntrinsic::Fmod => "fmodf", + GpuIntrinsic::Remainder => "remainderf", + GpuIntrinsic::Copysign => "copysignf", + GpuIntrinsic::Cbrt => "cbrtf", + GpuIntrinsic::Hypot => "hypotf", + + // Trigonometric GpuIntrinsic::Sin => "sinf", GpuIntrinsic::Cos => "cosf", GpuIntrinsic::Tan => "tanf", + GpuIntrinsic::Asin => "asinf", + GpuIntrinsic::Acos => "acosf", + GpuIntrinsic::Atan => "atanf", + GpuIntrinsic::Atan2 => "atan2f", + GpuIntrinsic::Sincos => "sincosf", + GpuIntrinsic::Sinpi => "sinpif", + GpuIntrinsic::Cospi => "cospif", + + // Hyperbolic + GpuIntrinsic::Sinh => "sinhf", + GpuIntrinsic::Cosh => "coshf", + GpuIntrinsic::Tanh => "tanhf", + GpuIntrinsic::Asinh => "asinhf", + GpuIntrinsic::Acosh => "acoshf", + GpuIntrinsic::Atanh => "atanhf", + + // Exponential and logarithmic GpuIntrinsic::Exp => "expf", + GpuIntrinsic::Exp2 => "exp2f", + GpuIntrinsic::Exp10 => "exp10f", + GpuIntrinsic::Expm1 => "expm1f", GpuIntrinsic::Log => "logf", + GpuIntrinsic::Log2 => "log2f", + GpuIntrinsic::Log10 => "log10f", + GpuIntrinsic::Log1p => "log1pf", GpuIntrinsic::Pow => "powf", - GpuIntrinsic::Fma => "fmaf", - GpuIntrinsic::Min => "fminf", - GpuIntrinsic::Max => "fmaxf", + GpuIntrinsic::Ldexp => "ldexpf", + GpuIntrinsic::Scalbn => "scalbnf", + GpuIntrinsic::Ilogb => "ilogbf", + GpuIntrinsic::Lgamma => "lgammaf", + GpuIntrinsic::Tgamma => "tgammaf", + GpuIntrinsic::Erf => "erff", + GpuIntrinsic::Erfc => "erfcf", + GpuIntrinsic::Erfinv => "erfinvf", + GpuIntrinsic::Erfcinv => "erfcinvf", + + // Classification and comparison + GpuIntrinsic::Isnan => "isnan", + GpuIntrinsic::Isinf => "isinf", + GpuIntrinsic::Isfinite => "isfinite", + GpuIntrinsic::Isnormal => "isnormal", + GpuIntrinsic::Signbit => "signbit", + GpuIntrinsic::Nextafter => "nextafterf", + GpuIntrinsic::Fdim => "fdimf", + GpuIntrinsic::Nan => "nanf", + + // Warp-level operations GpuIntrinsic::WarpShfl => "__shfl_sync", GpuIntrinsic::WarpShflUp => "__shfl_up_sync", GpuIntrinsic::WarpShflDown => "__shfl_down_sync", @@ -104,6 +402,46 @@ impl GpuIntrinsic { GpuIntrinsic::WarpBallot => "__ballot_sync", GpuIntrinsic::WarpAll => "__all_sync", GpuIntrinsic::WarpAny => "__any_sync", + GpuIntrinsic::WarpMatchAny => "__match_any_sync", + GpuIntrinsic::WarpMatchAll => "__match_all_sync", + GpuIntrinsic::WarpReduceAdd => "__reduce_add_sync", + GpuIntrinsic::WarpReduceMin => "__reduce_min_sync", + GpuIntrinsic::WarpReduceMax => "__reduce_max_sync", + GpuIntrinsic::WarpReduceAnd => "__reduce_and_sync", + GpuIntrinsic::WarpReduceOr => "__reduce_or_sync", + GpuIntrinsic::WarpReduceXor => "__reduce_xor_sync", + + // Bit manipulation + GpuIntrinsic::Popc => "__popc", + GpuIntrinsic::Clz => "__clz", + GpuIntrinsic::Ctz => "__ffs", // ffs returns 1 + ctz, but commonly used + GpuIntrinsic::Ffs => "__ffs", + GpuIntrinsic::Brev => "__brev", + GpuIntrinsic::BytePerm => "__byte_perm", + GpuIntrinsic::FunnelShiftLeft => "__funnelshift_l", + GpuIntrinsic::FunnelShiftRight => "__funnelshift_r", + + // Memory operations + GpuIntrinsic::Ldg => "__ldg", + GpuIntrinsic::PrefetchL1 => "__prefetch_l1", + GpuIntrinsic::PrefetchL2 => "__prefetch_l2", + + // Special functions + GpuIntrinsic::Rcp => "__frcp_rn", + GpuIntrinsic::Fdividef => "__fdividef", + GpuIntrinsic::Saturate => "__saturatef", + GpuIntrinsic::J0 => "j0f", + GpuIntrinsic::J1 => "j1f", + GpuIntrinsic::Jn => "jnf", + GpuIntrinsic::Y0 => "y0f", + GpuIntrinsic::Y1 => "y1f", + GpuIntrinsic::Yn => "ynf", + GpuIntrinsic::Normcdf => "normcdff", + GpuIntrinsic::Normcdfinv => "normcdfinvf", + GpuIntrinsic::CylBesselI0 => "cyl_bessel_i0f", + GpuIntrinsic::CylBesselI1 => "cyl_bessel_i1f", + + // Thread/block indices GpuIntrinsic::ThreadIdxX => "threadIdx.x", GpuIntrinsic::ThreadIdxY => "threadIdx.y", GpuIntrinsic::ThreadIdxZ => "threadIdx.z", @@ -116,6 +454,214 @@ impl GpuIntrinsic { GpuIntrinsic::GridDimX => "gridDim.x", GpuIntrinsic::GridDimY => "gridDim.y", GpuIntrinsic::GridDimZ => "gridDim.z", + GpuIntrinsic::WarpSize => "warpSize", + + // Clock and timing + GpuIntrinsic::Clock => "clock()", + GpuIntrinsic::Clock64 => "clock64()", + GpuIntrinsic::Nanosleep => "__nanosleep", + } + } + + /// Check if this intrinsic is a value (no parentheses needed). + pub fn is_value_intrinsic(&self) -> bool { + matches!( + self, + GpuIntrinsic::ThreadIdxX + | GpuIntrinsic::ThreadIdxY + | GpuIntrinsic::ThreadIdxZ + | GpuIntrinsic::BlockIdxX + | GpuIntrinsic::BlockIdxY + | GpuIntrinsic::BlockIdxZ + | GpuIntrinsic::BlockDimX + | GpuIntrinsic::BlockDimY + | GpuIntrinsic::BlockDimZ + | GpuIntrinsic::GridDimX + | GpuIntrinsic::GridDimY + | GpuIntrinsic::GridDimZ + | GpuIntrinsic::WarpSize + ) + } + + /// Check if this intrinsic is a zero-argument function (ends with ()). + pub fn is_zero_arg_function(&self) -> bool { + matches!( + self, + GpuIntrinsic::SyncThreads + | GpuIntrinsic::ThreadFence + | GpuIntrinsic::ThreadFenceBlock + | GpuIntrinsic::ThreadFenceSystem + | GpuIntrinsic::WarpActiveMask + | GpuIntrinsic::Clock + | GpuIntrinsic::Clock64 + ) + } + + /// Check if this intrinsic requires a mask argument (warp operations). + pub fn requires_mask(&self) -> bool { + matches!( + self, + GpuIntrinsic::WarpShfl + | GpuIntrinsic::WarpShflUp + | GpuIntrinsic::WarpShflDown + | GpuIntrinsic::WarpShflXor + | GpuIntrinsic::WarpBallot + | GpuIntrinsic::WarpAll + | GpuIntrinsic::WarpAny + | GpuIntrinsic::WarpMatchAny + | GpuIntrinsic::WarpMatchAll + | GpuIntrinsic::WarpReduceAdd + | GpuIntrinsic::WarpReduceMin + | GpuIntrinsic::WarpReduceMax + | GpuIntrinsic::WarpReduceAnd + | GpuIntrinsic::WarpReduceOr + | GpuIntrinsic::WarpReduceXor + ) + } + + /// Get the category of this intrinsic for documentation purposes. + pub fn category(&self) -> &'static str { + match self { + GpuIntrinsic::SyncThreads + | GpuIntrinsic::ThreadFence + | GpuIntrinsic::ThreadFenceBlock + | GpuIntrinsic::ThreadFenceSystem + | GpuIntrinsic::SyncThreadsCount + | GpuIntrinsic::SyncThreadsAnd + | GpuIntrinsic::SyncThreadsOr => "synchronization", + + GpuIntrinsic::AtomicAdd + | GpuIntrinsic::AtomicSub + | GpuIntrinsic::AtomicMin + | GpuIntrinsic::AtomicMax + | GpuIntrinsic::AtomicExch + | GpuIntrinsic::AtomicCas + | GpuIntrinsic::AtomicAnd + | GpuIntrinsic::AtomicOr + | GpuIntrinsic::AtomicXor + | GpuIntrinsic::AtomicInc + | GpuIntrinsic::AtomicDec => "atomic", + + GpuIntrinsic::Sqrt + | GpuIntrinsic::Rsqrt + | GpuIntrinsic::Abs + | GpuIntrinsic::Fabs + | GpuIntrinsic::Floor + | GpuIntrinsic::Ceil + | GpuIntrinsic::Round + | GpuIntrinsic::Trunc + | GpuIntrinsic::Fma + | GpuIntrinsic::Min + | GpuIntrinsic::Max + | GpuIntrinsic::Fmod + | GpuIntrinsic::Remainder + | GpuIntrinsic::Copysign + | GpuIntrinsic::Cbrt + | GpuIntrinsic::Hypot => "math", + + GpuIntrinsic::Sin + | GpuIntrinsic::Cos + | GpuIntrinsic::Tan + | GpuIntrinsic::Asin + | GpuIntrinsic::Acos + | GpuIntrinsic::Atan + | GpuIntrinsic::Atan2 + | GpuIntrinsic::Sincos + | GpuIntrinsic::Sinpi + | GpuIntrinsic::Cospi => "trigonometric", + + GpuIntrinsic::Sinh + | GpuIntrinsic::Cosh + | GpuIntrinsic::Tanh + | GpuIntrinsic::Asinh + | GpuIntrinsic::Acosh + | GpuIntrinsic::Atanh => "hyperbolic", + + GpuIntrinsic::Exp + | GpuIntrinsic::Exp2 + | GpuIntrinsic::Exp10 + | GpuIntrinsic::Expm1 + | GpuIntrinsic::Log + | GpuIntrinsic::Log2 + | GpuIntrinsic::Log10 + | GpuIntrinsic::Log1p + | GpuIntrinsic::Pow + | GpuIntrinsic::Ldexp + | GpuIntrinsic::Scalbn + | GpuIntrinsic::Ilogb + | GpuIntrinsic::Lgamma + | GpuIntrinsic::Tgamma + | GpuIntrinsic::Erf + | GpuIntrinsic::Erfc + | GpuIntrinsic::Erfinv + | GpuIntrinsic::Erfcinv => "exponential", + + GpuIntrinsic::Isnan + | GpuIntrinsic::Isinf + | GpuIntrinsic::Isfinite + | GpuIntrinsic::Isnormal + | GpuIntrinsic::Signbit + | GpuIntrinsic::Nextafter + | GpuIntrinsic::Fdim + | GpuIntrinsic::Nan => "classification", + + GpuIntrinsic::WarpShfl + | GpuIntrinsic::WarpShflUp + | GpuIntrinsic::WarpShflDown + | GpuIntrinsic::WarpShflXor + | GpuIntrinsic::WarpActiveMask + | GpuIntrinsic::WarpBallot + | GpuIntrinsic::WarpAll + | GpuIntrinsic::WarpAny + | GpuIntrinsic::WarpMatchAny + | GpuIntrinsic::WarpMatchAll + | GpuIntrinsic::WarpReduceAdd + | GpuIntrinsic::WarpReduceMin + | GpuIntrinsic::WarpReduceMax + | GpuIntrinsic::WarpReduceAnd + | GpuIntrinsic::WarpReduceOr + | GpuIntrinsic::WarpReduceXor => "warp", + + GpuIntrinsic::Popc + | GpuIntrinsic::Clz + | GpuIntrinsic::Ctz + | GpuIntrinsic::Ffs + | GpuIntrinsic::Brev + | GpuIntrinsic::BytePerm + | GpuIntrinsic::FunnelShiftLeft + | GpuIntrinsic::FunnelShiftRight => "bit", + + GpuIntrinsic::Ldg | GpuIntrinsic::PrefetchL1 | GpuIntrinsic::PrefetchL2 => "memory", + + GpuIntrinsic::Rcp + | GpuIntrinsic::Fdividef + | GpuIntrinsic::Saturate + | GpuIntrinsic::J0 + | GpuIntrinsic::J1 + | GpuIntrinsic::Jn + | GpuIntrinsic::Y0 + | GpuIntrinsic::Y1 + | GpuIntrinsic::Yn + | GpuIntrinsic::Normcdf + | GpuIntrinsic::Normcdfinv + | GpuIntrinsic::CylBesselI0 + | GpuIntrinsic::CylBesselI1 => "special", + + GpuIntrinsic::ThreadIdxX + | GpuIntrinsic::ThreadIdxY + | GpuIntrinsic::ThreadIdxZ + | GpuIntrinsic::BlockIdxX + | GpuIntrinsic::BlockIdxY + | GpuIntrinsic::BlockIdxZ + | GpuIntrinsic::BlockDimX + | GpuIntrinsic::BlockDimY + | GpuIntrinsic::BlockDimZ + | GpuIntrinsic::GridDimX + | GpuIntrinsic::GridDimY + | GpuIntrinsic::GridDimZ + | GpuIntrinsic::WarpSize => "index", + + GpuIntrinsic::Clock | GpuIntrinsic::Clock64 | GpuIntrinsic::Nanosleep => "timing", } } } @@ -137,45 +683,176 @@ impl IntrinsicRegistry { pub fn new() -> Self { let mut mappings = HashMap::new(); - // Synchronization + // === Synchronization === mappings.insert("sync_threads".to_string(), GpuIntrinsic::SyncThreads); mappings.insert("thread_fence".to_string(), GpuIntrinsic::ThreadFence); - mappings.insert( - "thread_fence_block".to_string(), - GpuIntrinsic::ThreadFenceBlock, - ); - mappings.insert( - "thread_fence_system".to_string(), - GpuIntrinsic::ThreadFenceSystem, - ); + mappings.insert("thread_fence_block".to_string(), GpuIntrinsic::ThreadFenceBlock); + mappings.insert("thread_fence_system".to_string(), GpuIntrinsic::ThreadFenceSystem); + mappings.insert("sync_threads_count".to_string(), GpuIntrinsic::SyncThreadsCount); + mappings.insert("sync_threads_and".to_string(), GpuIntrinsic::SyncThreadsAnd); + mappings.insert("sync_threads_or".to_string(), GpuIntrinsic::SyncThreadsOr); - // Atomics (common naming) + // === Atomic operations === mappings.insert("atomic_add".to_string(), GpuIntrinsic::AtomicAdd); mappings.insert("atomic_sub".to_string(), GpuIntrinsic::AtomicSub); mappings.insert("atomic_min".to_string(), GpuIntrinsic::AtomicMin); mappings.insert("atomic_max".to_string(), GpuIntrinsic::AtomicMax); mappings.insert("atomic_exchange".to_string(), GpuIntrinsic::AtomicExch); + mappings.insert("atomic_exch".to_string(), GpuIntrinsic::AtomicExch); mappings.insert("atomic_cas".to_string(), GpuIntrinsic::AtomicCas); + mappings.insert("atomic_compare_swap".to_string(), GpuIntrinsic::AtomicCas); + mappings.insert("atomic_and".to_string(), GpuIntrinsic::AtomicAnd); + mappings.insert("atomic_or".to_string(), GpuIntrinsic::AtomicOr); + mappings.insert("atomic_xor".to_string(), GpuIntrinsic::AtomicXor); + mappings.insert("atomic_inc".to_string(), GpuIntrinsic::AtomicInc); + mappings.insert("atomic_dec".to_string(), GpuIntrinsic::AtomicDec); - // Math functions (Rust std naming) + // === Basic math functions === mappings.insert("sqrt".to_string(), GpuIntrinsic::Sqrt); + mappings.insert("rsqrt".to_string(), GpuIntrinsic::Rsqrt); mappings.insert("abs".to_string(), GpuIntrinsic::Fabs); + mappings.insert("fabs".to_string(), GpuIntrinsic::Fabs); mappings.insert("floor".to_string(), GpuIntrinsic::Floor); mappings.insert("ceil".to_string(), GpuIntrinsic::Ceil); mappings.insert("round".to_string(), GpuIntrinsic::Round); + mappings.insert("trunc".to_string(), GpuIntrinsic::Trunc); + mappings.insert("mul_add".to_string(), GpuIntrinsic::Fma); + mappings.insert("fma".to_string(), GpuIntrinsic::Fma); + mappings.insert("min".to_string(), GpuIntrinsic::Min); + mappings.insert("max".to_string(), GpuIntrinsic::Max); + mappings.insert("fmin".to_string(), GpuIntrinsic::Min); + mappings.insert("fmax".to_string(), GpuIntrinsic::Max); + mappings.insert("fmod".to_string(), GpuIntrinsic::Fmod); + mappings.insert("remainder".to_string(), GpuIntrinsic::Remainder); + mappings.insert("copysign".to_string(), GpuIntrinsic::Copysign); + mappings.insert("cbrt".to_string(), GpuIntrinsic::Cbrt); + mappings.insert("hypot".to_string(), GpuIntrinsic::Hypot); + + // === Trigonometric functions === mappings.insert("sin".to_string(), GpuIntrinsic::Sin); mappings.insert("cos".to_string(), GpuIntrinsic::Cos); mappings.insert("tan".to_string(), GpuIntrinsic::Tan); + mappings.insert("asin".to_string(), GpuIntrinsic::Asin); + mappings.insert("acos".to_string(), GpuIntrinsic::Acos); + mappings.insert("atan".to_string(), GpuIntrinsic::Atan); + mappings.insert("atan2".to_string(), GpuIntrinsic::Atan2); + mappings.insert("sincos".to_string(), GpuIntrinsic::Sincos); + mappings.insert("sinpi".to_string(), GpuIntrinsic::Sinpi); + mappings.insert("cospi".to_string(), GpuIntrinsic::Cospi); + + // === Hyperbolic functions === + mappings.insert("sinh".to_string(), GpuIntrinsic::Sinh); + mappings.insert("cosh".to_string(), GpuIntrinsic::Cosh); + mappings.insert("tanh".to_string(), GpuIntrinsic::Tanh); + mappings.insert("asinh".to_string(), GpuIntrinsic::Asinh); + mappings.insert("acosh".to_string(), GpuIntrinsic::Acosh); + mappings.insert("atanh".to_string(), GpuIntrinsic::Atanh); + + // === Exponential and logarithmic === mappings.insert("exp".to_string(), GpuIntrinsic::Exp); + mappings.insert("exp2".to_string(), GpuIntrinsic::Exp2); + mappings.insert("exp10".to_string(), GpuIntrinsic::Exp10); + mappings.insert("expm1".to_string(), GpuIntrinsic::Expm1); mappings.insert("ln".to_string(), GpuIntrinsic::Log); mappings.insert("log".to_string(), GpuIntrinsic::Log); + mappings.insert("log2".to_string(), GpuIntrinsic::Log2); + mappings.insert("log10".to_string(), GpuIntrinsic::Log10); + mappings.insert("log1p".to_string(), GpuIntrinsic::Log1p); mappings.insert("powf".to_string(), GpuIntrinsic::Pow); mappings.insert("powi".to_string(), GpuIntrinsic::Pow); - mappings.insert("mul_add".to_string(), GpuIntrinsic::Fma); - mappings.insert("min".to_string(), GpuIntrinsic::Min); - mappings.insert("max".to_string(), GpuIntrinsic::Max); + mappings.insert("pow".to_string(), GpuIntrinsic::Pow); + mappings.insert("ldexp".to_string(), GpuIntrinsic::Ldexp); + mappings.insert("scalbn".to_string(), GpuIntrinsic::Scalbn); + mappings.insert("ilogb".to_string(), GpuIntrinsic::Ilogb); + mappings.insert("lgamma".to_string(), GpuIntrinsic::Lgamma); + mappings.insert("tgamma".to_string(), GpuIntrinsic::Tgamma); + mappings.insert("gamma".to_string(), GpuIntrinsic::Tgamma); + mappings.insert("erf".to_string(), GpuIntrinsic::Erf); + mappings.insert("erfc".to_string(), GpuIntrinsic::Erfc); + mappings.insert("erfinv".to_string(), GpuIntrinsic::Erfinv); + mappings.insert("erfcinv".to_string(), GpuIntrinsic::Erfcinv); - // CUDA thread/block indices (function-style access in Rust DSL) + // === Classification and comparison === + mappings.insert("is_nan".to_string(), GpuIntrinsic::Isnan); + mappings.insert("isnan".to_string(), GpuIntrinsic::Isnan); + mappings.insert("is_infinite".to_string(), GpuIntrinsic::Isinf); + mappings.insert("isinf".to_string(), GpuIntrinsic::Isinf); + mappings.insert("is_finite".to_string(), GpuIntrinsic::Isfinite); + mappings.insert("isfinite".to_string(), GpuIntrinsic::Isfinite); + mappings.insert("is_normal".to_string(), GpuIntrinsic::Isnormal); + mappings.insert("isnormal".to_string(), GpuIntrinsic::Isnormal); + mappings.insert("is_sign_negative".to_string(), GpuIntrinsic::Signbit); + mappings.insert("signbit".to_string(), GpuIntrinsic::Signbit); + mappings.insert("nextafter".to_string(), GpuIntrinsic::Nextafter); + mappings.insert("fdim".to_string(), GpuIntrinsic::Fdim); + mappings.insert("nan".to_string(), GpuIntrinsic::Nan); + + // === Warp operations === + mappings.insert("warp_shfl".to_string(), GpuIntrinsic::WarpShfl); + mappings.insert("warp_shuffle".to_string(), GpuIntrinsic::WarpShfl); + mappings.insert("warp_shfl_up".to_string(), GpuIntrinsic::WarpShflUp); + mappings.insert("warp_shuffle_up".to_string(), GpuIntrinsic::WarpShflUp); + mappings.insert("warp_shfl_down".to_string(), GpuIntrinsic::WarpShflDown); + mappings.insert("warp_shuffle_down".to_string(), GpuIntrinsic::WarpShflDown); + mappings.insert("warp_shfl_xor".to_string(), GpuIntrinsic::WarpShflXor); + mappings.insert("warp_shuffle_xor".to_string(), GpuIntrinsic::WarpShflXor); + mappings.insert("warp_active_mask".to_string(), GpuIntrinsic::WarpActiveMask); + mappings.insert("active_mask".to_string(), GpuIntrinsic::WarpActiveMask); + mappings.insert("warp_ballot".to_string(), GpuIntrinsic::WarpBallot); + mappings.insert("ballot".to_string(), GpuIntrinsic::WarpBallot); + mappings.insert("warp_all".to_string(), GpuIntrinsic::WarpAll); + mappings.insert("warp_any".to_string(), GpuIntrinsic::WarpAny); + mappings.insert("warp_match_any".to_string(), GpuIntrinsic::WarpMatchAny); + mappings.insert("warp_match_all".to_string(), GpuIntrinsic::WarpMatchAll); + mappings.insert("warp_reduce_add".to_string(), GpuIntrinsic::WarpReduceAdd); + mappings.insert("warp_reduce_min".to_string(), GpuIntrinsic::WarpReduceMin); + mappings.insert("warp_reduce_max".to_string(), GpuIntrinsic::WarpReduceMax); + mappings.insert("warp_reduce_and".to_string(), GpuIntrinsic::WarpReduceAnd); + mappings.insert("warp_reduce_or".to_string(), GpuIntrinsic::WarpReduceOr); + mappings.insert("warp_reduce_xor".to_string(), GpuIntrinsic::WarpReduceXor); + + // === Bit manipulation === + mappings.insert("popc".to_string(), GpuIntrinsic::Popc); + mappings.insert("popcount".to_string(), GpuIntrinsic::Popc); + mappings.insert("count_ones".to_string(), GpuIntrinsic::Popc); + mappings.insert("clz".to_string(), GpuIntrinsic::Clz); + mappings.insert("leading_zeros".to_string(), GpuIntrinsic::Clz); + mappings.insert("ctz".to_string(), GpuIntrinsic::Ctz); + mappings.insert("trailing_zeros".to_string(), GpuIntrinsic::Ctz); + mappings.insert("ffs".to_string(), GpuIntrinsic::Ffs); + mappings.insert("brev".to_string(), GpuIntrinsic::Brev); + mappings.insert("reverse_bits".to_string(), GpuIntrinsic::Brev); + mappings.insert("byte_perm".to_string(), GpuIntrinsic::BytePerm); + mappings.insert("funnel_shift_left".to_string(), GpuIntrinsic::FunnelShiftLeft); + mappings.insert("funnel_shift_right".to_string(), GpuIntrinsic::FunnelShiftRight); + + // === Memory operations === + mappings.insert("ldg".to_string(), GpuIntrinsic::Ldg); + mappings.insert("load_global".to_string(), GpuIntrinsic::Ldg); + mappings.insert("prefetch_l1".to_string(), GpuIntrinsic::PrefetchL1); + mappings.insert("prefetch_l2".to_string(), GpuIntrinsic::PrefetchL2); + + // === Special functions === + mappings.insert("rcp".to_string(), GpuIntrinsic::Rcp); + mappings.insert("recip".to_string(), GpuIntrinsic::Rcp); + mappings.insert("fdividef".to_string(), GpuIntrinsic::Fdividef); + mappings.insert("fast_div".to_string(), GpuIntrinsic::Fdividef); + mappings.insert("saturate".to_string(), GpuIntrinsic::Saturate); + mappings.insert("clamp_01".to_string(), GpuIntrinsic::Saturate); + mappings.insert("j0".to_string(), GpuIntrinsic::J0); + mappings.insert("j1".to_string(), GpuIntrinsic::J1); + mappings.insert("jn".to_string(), GpuIntrinsic::Jn); + mappings.insert("y0".to_string(), GpuIntrinsic::Y0); + mappings.insert("y1".to_string(), GpuIntrinsic::Y1); + mappings.insert("yn".to_string(), GpuIntrinsic::Yn); + mappings.insert("normcdf".to_string(), GpuIntrinsic::Normcdf); + mappings.insert("norm_cdf".to_string(), GpuIntrinsic::Normcdf); + mappings.insert("normcdfinv".to_string(), GpuIntrinsic::Normcdfinv); + mappings.insert("norm_cdf_inv".to_string(), GpuIntrinsic::Normcdfinv); + mappings.insert("cyl_bessel_i0".to_string(), GpuIntrinsic::CylBesselI0); + mappings.insert("cyl_bessel_i1".to_string(), GpuIntrinsic::CylBesselI1); + + // === Thread/block indices === mappings.insert("thread_idx_x".to_string(), GpuIntrinsic::ThreadIdxX); mappings.insert("thread_idx_y".to_string(), GpuIntrinsic::ThreadIdxY); mappings.insert("thread_idx_z".to_string(), GpuIntrinsic::ThreadIdxZ); @@ -188,6 +865,12 @@ impl IntrinsicRegistry { mappings.insert("grid_dim_x".to_string(), GpuIntrinsic::GridDimX); mappings.insert("grid_dim_y".to_string(), GpuIntrinsic::GridDimY); mappings.insert("grid_dim_z".to_string(), GpuIntrinsic::GridDimZ); + mappings.insert("warp_size".to_string(), GpuIntrinsic::WarpSize); + + // === Clock and timing === + mappings.insert("clock".to_string(), GpuIntrinsic::Clock); + mappings.insert("clock64".to_string(), GpuIntrinsic::Clock64); + mappings.insert("nanosleep".to_string(), GpuIntrinsic::Nanosleep); Self { mappings } } @@ -459,6 +1142,28 @@ impl StencilIntrinsic { } } + /// Get the index offset for 3D stencil. + /// + /// Returns (z_offset, row_offset, col_offset) where final offset is: + /// `z_offset * buffer_slice + row_offset * buffer_width + col_offset` + pub fn get_offset_3d(&self) -> Option<(i32, i32, i32)> { + match self { + StencilIntrinsic::Index => Some((0, 0, 0)), + StencilIntrinsic::North => Some((0, -1, 0)), + StencilIntrinsic::South => Some((0, 1, 0)), + StencilIntrinsic::East => Some((0, 0, 1)), + StencilIntrinsic::West => Some((0, 0, -1)), + StencilIntrinsic::Up => Some((-1, 0, 0)), + StencilIntrinsic::Down => Some((1, 0, 0)), + StencilIntrinsic::At => None, // Requires runtime offset + } + } + + /// Check if this is a 3D-only intrinsic. + pub fn is_3d_only(&self) -> bool { + matches!(self, StencilIntrinsic::Up | StencilIntrinsic::Down) + } + /// Generate CUDA index expression for 2D stencil. /// /// # Arguments @@ -483,6 +1188,43 @@ impl StencilIntrinsic { _ => format!("{}[{}]", buffer_name, idx_var), } } + + /// Generate CUDA index expression for 3D stencil. + /// + /// # Arguments + /// * `buffer_name` - Name of the buffer variable + /// * `buffer_width` - Width expression + /// * `buffer_slice` - Slice size expression (width * height) + /// * `idx_var` - Name of the current index variable + pub fn to_cuda_index_3d( + &self, + buffer_name: &str, + buffer_width: &str, + buffer_slice: &str, + idx_var: &str, + ) -> String { + match self { + StencilIntrinsic::Index => format!("{}[{}]", buffer_name, idx_var), + StencilIntrinsic::North => { + format!("{}[{} - {}]", buffer_name, idx_var, buffer_width) + } + StencilIntrinsic::South => { + format!("{}[{} + {}]", buffer_name, idx_var, buffer_width) + } + StencilIntrinsic::East => format!("{}[{} + 1]", buffer_name, idx_var), + StencilIntrinsic::West => format!("{}[{} - 1]", buffer_name, idx_var), + StencilIntrinsic::Up => { + format!("{}[{} - {}]", buffer_name, idx_var, buffer_slice) + } + StencilIntrinsic::Down => { + format!("{}[{} + {}]", buffer_name, idx_var, buffer_slice) + } + StencilIntrinsic::At => { + // This should be handled specially with provided offsets + format!("{}[{}]", buffer_name, idx_var) + } + } + } } #[cfg(test)] @@ -664,4 +1406,298 @@ mod tests { assert!(RingKernelIntrinsic::EnqueueResponse.requires_control_block()); assert!(!RingKernelIntrinsic::HlcTick.requires_control_block()); } + + // === NEW INTRINSIC TESTS === + + #[test] + fn test_new_atomic_intrinsics() { + let registry = IntrinsicRegistry::new(); + + // Test bitwise atomics + assert_eq!(registry.lookup("atomic_and"), Some(&GpuIntrinsic::AtomicAnd)); + assert_eq!(registry.lookup("atomic_or"), Some(&GpuIntrinsic::AtomicOr)); + assert_eq!(registry.lookup("atomic_xor"), Some(&GpuIntrinsic::AtomicXor)); + assert_eq!(registry.lookup("atomic_inc"), Some(&GpuIntrinsic::AtomicInc)); + assert_eq!(registry.lookup("atomic_dec"), Some(&GpuIntrinsic::AtomicDec)); + + // Test CUDA output + assert_eq!(GpuIntrinsic::AtomicAnd.to_cuda_string(), "atomicAnd"); + assert_eq!(GpuIntrinsic::AtomicOr.to_cuda_string(), "atomicOr"); + assert_eq!(GpuIntrinsic::AtomicXor.to_cuda_string(), "atomicXor"); + assert_eq!(GpuIntrinsic::AtomicInc.to_cuda_string(), "atomicInc"); + assert_eq!(GpuIntrinsic::AtomicDec.to_cuda_string(), "atomicDec"); + } + + #[test] + fn test_trigonometric_intrinsics() { + let registry = IntrinsicRegistry::new(); + + // Test inverse trig + assert_eq!(registry.lookup("asin"), Some(&GpuIntrinsic::Asin)); + assert_eq!(registry.lookup("acos"), Some(&GpuIntrinsic::Acos)); + assert_eq!(registry.lookup("atan"), Some(&GpuIntrinsic::Atan)); + assert_eq!(registry.lookup("atan2"), Some(&GpuIntrinsic::Atan2)); + + // Test CUDA output + assert_eq!(GpuIntrinsic::Asin.to_cuda_string(), "asinf"); + assert_eq!(GpuIntrinsic::Acos.to_cuda_string(), "acosf"); + assert_eq!(GpuIntrinsic::Atan.to_cuda_string(), "atanf"); + assert_eq!(GpuIntrinsic::Atan2.to_cuda_string(), "atan2f"); + } + + #[test] + fn test_hyperbolic_intrinsics() { + let registry = IntrinsicRegistry::new(); + + // Test hyperbolic functions + assert_eq!(registry.lookup("sinh"), Some(&GpuIntrinsic::Sinh)); + assert_eq!(registry.lookup("cosh"), Some(&GpuIntrinsic::Cosh)); + assert_eq!(registry.lookup("tanh"), Some(&GpuIntrinsic::Tanh)); + assert_eq!(registry.lookup("asinh"), Some(&GpuIntrinsic::Asinh)); + assert_eq!(registry.lookup("acosh"), Some(&GpuIntrinsic::Acosh)); + assert_eq!(registry.lookup("atanh"), Some(&GpuIntrinsic::Atanh)); + + // Test CUDA output + assert_eq!(GpuIntrinsic::Sinh.to_cuda_string(), "sinhf"); + assert_eq!(GpuIntrinsic::Cosh.to_cuda_string(), "coshf"); + assert_eq!(GpuIntrinsic::Tanh.to_cuda_string(), "tanhf"); + } + + #[test] + fn test_exponential_logarithmic_intrinsics() { + let registry = IntrinsicRegistry::new(); + + // Test exp variants + assert_eq!(registry.lookup("exp2"), Some(&GpuIntrinsic::Exp2)); + assert_eq!(registry.lookup("exp10"), Some(&GpuIntrinsic::Exp10)); + assert_eq!(registry.lookup("expm1"), Some(&GpuIntrinsic::Expm1)); + + // Test log variants + assert_eq!(registry.lookup("log2"), Some(&GpuIntrinsic::Log2)); + assert_eq!(registry.lookup("log10"), Some(&GpuIntrinsic::Log10)); + assert_eq!(registry.lookup("log1p"), Some(&GpuIntrinsic::Log1p)); + + // Test CUDA output + assert_eq!(GpuIntrinsic::Exp2.to_cuda_string(), "exp2f"); + assert_eq!(GpuIntrinsic::Log2.to_cuda_string(), "log2f"); + assert_eq!(GpuIntrinsic::Log10.to_cuda_string(), "log10f"); + } + + #[test] + fn test_classification_intrinsics() { + let registry = IntrinsicRegistry::new(); + + // Test classification functions + assert_eq!(registry.lookup("is_nan"), Some(&GpuIntrinsic::Isnan)); + assert_eq!(registry.lookup("isnan"), Some(&GpuIntrinsic::Isnan)); + assert_eq!(registry.lookup("is_infinite"), Some(&GpuIntrinsic::Isinf)); + assert_eq!(registry.lookup("is_finite"), Some(&GpuIntrinsic::Isfinite)); + assert_eq!(registry.lookup("is_normal"), Some(&GpuIntrinsic::Isnormal)); + assert_eq!(registry.lookup("signbit"), Some(&GpuIntrinsic::Signbit)); + + // Test CUDA output + assert_eq!(GpuIntrinsic::Isnan.to_cuda_string(), "isnan"); + assert_eq!(GpuIntrinsic::Isinf.to_cuda_string(), "isinf"); + assert_eq!(GpuIntrinsic::Isfinite.to_cuda_string(), "isfinite"); + } + + #[test] + fn test_warp_reduce_intrinsics() { + let registry = IntrinsicRegistry::new(); + + // Test warp reduce operations + assert_eq!(registry.lookup("warp_reduce_add"), Some(&GpuIntrinsic::WarpReduceAdd)); + assert_eq!(registry.lookup("warp_reduce_min"), Some(&GpuIntrinsic::WarpReduceMin)); + assert_eq!(registry.lookup("warp_reduce_max"), Some(&GpuIntrinsic::WarpReduceMax)); + assert_eq!(registry.lookup("warp_reduce_and"), Some(&GpuIntrinsic::WarpReduceAnd)); + assert_eq!(registry.lookup("warp_reduce_or"), Some(&GpuIntrinsic::WarpReduceOr)); + assert_eq!(registry.lookup("warp_reduce_xor"), Some(&GpuIntrinsic::WarpReduceXor)); + + // Test CUDA output + assert_eq!(GpuIntrinsic::WarpReduceAdd.to_cuda_string(), "__reduce_add_sync"); + assert_eq!(GpuIntrinsic::WarpReduceMin.to_cuda_string(), "__reduce_min_sync"); + assert_eq!(GpuIntrinsic::WarpReduceMax.to_cuda_string(), "__reduce_max_sync"); + } + + #[test] + fn test_warp_match_intrinsics() { + let registry = IntrinsicRegistry::new(); + + assert_eq!(registry.lookup("warp_match_any"), Some(&GpuIntrinsic::WarpMatchAny)); + assert_eq!(registry.lookup("warp_match_all"), Some(&GpuIntrinsic::WarpMatchAll)); + + assert_eq!(GpuIntrinsic::WarpMatchAny.to_cuda_string(), "__match_any_sync"); + assert_eq!(GpuIntrinsic::WarpMatchAll.to_cuda_string(), "__match_all_sync"); + } + + #[test] + fn test_bit_manipulation_intrinsics() { + let registry = IntrinsicRegistry::new(); + + // Test bit manipulation + assert_eq!(registry.lookup("popc"), Some(&GpuIntrinsic::Popc)); + assert_eq!(registry.lookup("popcount"), Some(&GpuIntrinsic::Popc)); + assert_eq!(registry.lookup("count_ones"), Some(&GpuIntrinsic::Popc)); + assert_eq!(registry.lookup("clz"), Some(&GpuIntrinsic::Clz)); + assert_eq!(registry.lookup("leading_zeros"), Some(&GpuIntrinsic::Clz)); + assert_eq!(registry.lookup("ctz"), Some(&GpuIntrinsic::Ctz)); + assert_eq!(registry.lookup("ffs"), Some(&GpuIntrinsic::Ffs)); + assert_eq!(registry.lookup("brev"), Some(&GpuIntrinsic::Brev)); + assert_eq!(registry.lookup("reverse_bits"), Some(&GpuIntrinsic::Brev)); + + // Test CUDA output + assert_eq!(GpuIntrinsic::Popc.to_cuda_string(), "__popc"); + assert_eq!(GpuIntrinsic::Clz.to_cuda_string(), "__clz"); + assert_eq!(GpuIntrinsic::Ffs.to_cuda_string(), "__ffs"); + assert_eq!(GpuIntrinsic::Brev.to_cuda_string(), "__brev"); + } + + #[test] + fn test_funnel_shift_intrinsics() { + let registry = IntrinsicRegistry::new(); + + assert_eq!(registry.lookup("funnel_shift_left"), Some(&GpuIntrinsic::FunnelShiftLeft)); + assert_eq!(registry.lookup("funnel_shift_right"), Some(&GpuIntrinsic::FunnelShiftRight)); + + assert_eq!(GpuIntrinsic::FunnelShiftLeft.to_cuda_string(), "__funnelshift_l"); + assert_eq!(GpuIntrinsic::FunnelShiftRight.to_cuda_string(), "__funnelshift_r"); + } + + #[test] + fn test_memory_intrinsics() { + let registry = IntrinsicRegistry::new(); + + assert_eq!(registry.lookup("ldg"), Some(&GpuIntrinsic::Ldg)); + assert_eq!(registry.lookup("load_global"), Some(&GpuIntrinsic::Ldg)); + assert_eq!(registry.lookup("prefetch_l1"), Some(&GpuIntrinsic::PrefetchL1)); + assert_eq!(registry.lookup("prefetch_l2"), Some(&GpuIntrinsic::PrefetchL2)); + + assert_eq!(GpuIntrinsic::Ldg.to_cuda_string(), "__ldg"); + assert_eq!(GpuIntrinsic::PrefetchL1.to_cuda_string(), "__prefetch_l1"); + assert_eq!(GpuIntrinsic::PrefetchL2.to_cuda_string(), "__prefetch_l2"); + } + + #[test] + fn test_clock_intrinsics() { + let registry = IntrinsicRegistry::new(); + + assert_eq!(registry.lookup("clock"), Some(&GpuIntrinsic::Clock)); + assert_eq!(registry.lookup("clock64"), Some(&GpuIntrinsic::Clock64)); + assert_eq!(registry.lookup("nanosleep"), Some(&GpuIntrinsic::Nanosleep)); + + assert_eq!(GpuIntrinsic::Clock.to_cuda_string(), "clock()"); + assert_eq!(GpuIntrinsic::Clock64.to_cuda_string(), "clock64()"); + assert_eq!(GpuIntrinsic::Nanosleep.to_cuda_string(), "__nanosleep"); + } + + #[test] + fn test_special_function_intrinsics() { + let registry = IntrinsicRegistry::new(); + + assert_eq!(registry.lookup("rcp"), Some(&GpuIntrinsic::Rcp)); + assert_eq!(registry.lookup("recip"), Some(&GpuIntrinsic::Rcp)); + assert_eq!(registry.lookup("saturate"), Some(&GpuIntrinsic::Saturate)); + assert_eq!(registry.lookup("clamp_01"), Some(&GpuIntrinsic::Saturate)); + + assert_eq!(GpuIntrinsic::Rcp.to_cuda_string(), "__frcp_rn"); + assert_eq!(GpuIntrinsic::Saturate.to_cuda_string(), "__saturatef"); + } + + #[test] + fn test_intrinsic_categories() { + // Test category assignment + assert_eq!(GpuIntrinsic::SyncThreads.category(), "synchronization"); + assert_eq!(GpuIntrinsic::AtomicAdd.category(), "atomic"); + assert_eq!(GpuIntrinsic::Sqrt.category(), "math"); + assert_eq!(GpuIntrinsic::Sin.category(), "trigonometric"); + assert_eq!(GpuIntrinsic::Sinh.category(), "hyperbolic"); + assert_eq!(GpuIntrinsic::Exp.category(), "exponential"); + assert_eq!(GpuIntrinsic::Isnan.category(), "classification"); + assert_eq!(GpuIntrinsic::WarpShfl.category(), "warp"); + assert_eq!(GpuIntrinsic::Popc.category(), "bit"); + assert_eq!(GpuIntrinsic::Ldg.category(), "memory"); + assert_eq!(GpuIntrinsic::Rcp.category(), "special"); + assert_eq!(GpuIntrinsic::ThreadIdxX.category(), "index"); + assert_eq!(GpuIntrinsic::Clock.category(), "timing"); + } + + #[test] + fn test_intrinsic_flags() { + // Test is_value_intrinsic + assert!(GpuIntrinsic::ThreadIdxX.is_value_intrinsic()); + assert!(GpuIntrinsic::BlockDimX.is_value_intrinsic()); + assert!(GpuIntrinsic::WarpSize.is_value_intrinsic()); + assert!(!GpuIntrinsic::Sin.is_value_intrinsic()); + assert!(!GpuIntrinsic::AtomicAdd.is_value_intrinsic()); + + // Test is_zero_arg_function + assert!(GpuIntrinsic::SyncThreads.is_zero_arg_function()); + assert!(GpuIntrinsic::ThreadFence.is_zero_arg_function()); + assert!(GpuIntrinsic::WarpActiveMask.is_zero_arg_function()); + assert!(GpuIntrinsic::Clock.is_zero_arg_function()); + assert!(!GpuIntrinsic::Sin.is_zero_arg_function()); + + // Test requires_mask + assert!(GpuIntrinsic::WarpShfl.requires_mask()); + assert!(GpuIntrinsic::WarpBallot.requires_mask()); + assert!(GpuIntrinsic::WarpReduceAdd.requires_mask()); + assert!(!GpuIntrinsic::Sin.requires_mask()); + assert!(!GpuIntrinsic::AtomicAdd.requires_mask()); + } + + #[test] + fn test_3d_stencil_intrinsics() { + assert_eq!(StencilIntrinsic::from_method_name("up"), Some(StencilIntrinsic::Up)); + assert_eq!(StencilIntrinsic::from_method_name("down"), Some(StencilIntrinsic::Down)); + + // Test 3D only flag + assert!(StencilIntrinsic::Up.is_3d_only()); + assert!(StencilIntrinsic::Down.is_3d_only()); + assert!(!StencilIntrinsic::North.is_3d_only()); + assert!(!StencilIntrinsic::East.is_3d_only()); + assert!(!StencilIntrinsic::Index.is_3d_only()); + + // Test 3D offsets + assert_eq!(StencilIntrinsic::Up.get_offset_3d(), Some((-1, 0, 0))); + assert_eq!(StencilIntrinsic::Down.get_offset_3d(), Some((1, 0, 0))); + assert_eq!(StencilIntrinsic::North.get_offset_3d(), Some((0, -1, 0))); + assert_eq!(StencilIntrinsic::South.get_offset_3d(), Some((0, 1, 0))); + assert_eq!(StencilIntrinsic::East.get_offset_3d(), Some((0, 0, 1))); + assert_eq!(StencilIntrinsic::West.get_offset_3d(), Some((0, 0, -1))); + + // Test 3D index generation + let up = StencilIntrinsic::Up; + assert_eq!(up.to_cuda_index_3d("p", "18", "324", "idx"), "p[idx - 324]"); + + let down = StencilIntrinsic::Down; + assert_eq!(down.to_cuda_index_3d("p", "18", "324", "idx"), "p[idx + 324]"); + } + + #[test] + fn test_sync_intrinsics() { + let registry = IntrinsicRegistry::new(); + + assert_eq!(registry.lookup("sync_threads_count"), Some(&GpuIntrinsic::SyncThreadsCount)); + assert_eq!(registry.lookup("sync_threads_and"), Some(&GpuIntrinsic::SyncThreadsAnd)); + assert_eq!(registry.lookup("sync_threads_or"), Some(&GpuIntrinsic::SyncThreadsOr)); + + assert_eq!(GpuIntrinsic::SyncThreadsCount.to_cuda_string(), "__syncthreads_count"); + assert_eq!(GpuIntrinsic::SyncThreadsAnd.to_cuda_string(), "__syncthreads_and"); + assert_eq!(GpuIntrinsic::SyncThreadsOr.to_cuda_string(), "__syncthreads_or"); + } + + #[test] + fn test_math_extras() { + let registry = IntrinsicRegistry::new(); + + assert_eq!(registry.lookup("trunc"), Some(&GpuIntrinsic::Trunc)); + assert_eq!(registry.lookup("cbrt"), Some(&GpuIntrinsic::Cbrt)); + assert_eq!(registry.lookup("hypot"), Some(&GpuIntrinsic::Hypot)); + assert_eq!(registry.lookup("copysign"), Some(&GpuIntrinsic::Copysign)); + assert_eq!(registry.lookup("fmod"), Some(&GpuIntrinsic::Fmod)); + + assert_eq!(GpuIntrinsic::Trunc.to_cuda_string(), "truncf"); + assert_eq!(GpuIntrinsic::Cbrt.to_cuda_string(), "cbrtf"); + assert_eq!(GpuIntrinsic::Hypot.to_cuda_string(), "hypotf"); + } } diff --git a/crates/ringkernel-cuda-codegen/src/transpiler.rs b/crates/ringkernel-cuda-codegen/src/transpiler.rs index ec0bafc..85018be 100644 --- a/crates/ringkernel-cuda-codegen/src/transpiler.rs +++ b/crates/ringkernel-cuda-codegen/src/transpiler.rs @@ -814,11 +814,24 @@ impl CudaTranspiler { })?; let buffer_width = config.buffer_width().to_string(); + let buffer_slice = format!( + "{}", + config.buffer_width() * config.buffer_height() + ); + let is_3d = config.grid == crate::stencil::Grid::Grid3D; let intrinsic = StencilIntrinsic::from_method_name(method).ok_or_else(|| { TranspileError::Unsupported(format!("Unknown stencil intrinsic: {method}")) })?; + // Check if 3D intrinsic used in non-3D kernel + if intrinsic.is_3d_only() && !is_3d { + return Err(TranspileError::Unsupported(format!( + "3D stencil intrinsic '{}' requires Grid3D configuration", + method + ))); + } + match intrinsic { StencilIntrinsic::Index => { // pos.idx() -> idx @@ -835,25 +848,49 @@ impl CudaTranspiler { )); } let buffer = self.transpile_expr(&args[0])?; - Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx")) + if is_3d { + Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx")) + } else { + Ok(intrinsic.to_cuda_index_2d(&buffer, &buffer_width, "idx")) + } } - StencilIntrinsic::At => { - // pos.at(buf, dx, dy) -> buf[idx + dy * buffer_width + dx] - if args.len() < 3 { + StencilIntrinsic::Up | StencilIntrinsic::Down => { + // 3D intrinsics: pos.up(buf) -> buf[idx - buffer_slice] + if args.is_empty() { return Err(TranspileError::Unsupported( - "at() requires buffer, dx, dy arguments".into(), + "3D stencil accessor requires buffer argument".into(), )); } let buffer = self.transpile_expr(&args[0])?; - let dx = self.transpile_expr(&args[1])?; - let dy = self.transpile_expr(&args[2])?; - Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]")) + Ok(intrinsic.to_cuda_index_3d(&buffer, &buffer_width, &buffer_slice, "idx")) } - StencilIntrinsic::Up | StencilIntrinsic::Down => { - // 3D intrinsics - Err(TranspileError::Unsupported( - "3D stencil intrinsics not yet implemented".into(), - )) + StencilIntrinsic::At => { + // 2D: pos.at(buf, dx, dy) -> buf[idx + dy * buffer_width + dx] + // 3D: pos.at(buf, dx, dy, dz) -> buf[idx + dz * buffer_slice + dy * buffer_width + dx] + if is_3d { + if args.len() < 4 { + return Err(TranspileError::Unsupported( + "at() in 3D requires buffer, dx, dy, dz arguments".into(), + )); + } + let buffer = self.transpile_expr(&args[0])?; + let dx = self.transpile_expr(&args[1])?; + let dy = self.transpile_expr(&args[2])?; + let dz = self.transpile_expr(&args[3])?; + Ok(format!( + "{buffer}[idx + ({dz}) * {buffer_slice} + ({dy}) * {buffer_width} + ({dx})]" + )) + } else { + if args.len() < 3 { + return Err(TranspileError::Unsupported( + "at() requires buffer, dx, dy arguments".into(), + )); + } + let buffer = self.transpile_expr(&args[0])?; + let dx = self.transpile_expr(&args[1])?; + let dy = self.transpile_expr(&args[2])?; + Ok(format!("{buffer}[idx + ({dy}) * {buffer_width} + ({dx})]")) + } } } }