Merged
9 changes: 5 additions & 4 deletions CLAUDE.md
@@ -192,14 +192,15 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
```

**DSL Features:**
- Block/grid indices: `block_idx_x()`, `thread_idx_x()`, `block_dim_x()`, `grid_dim_x()`, etc.
- Block/grid indices: `block_idx_x()`, `thread_idx_x()`, `block_dim_x()`, `grid_dim_x()`, `warp_size()`, etc.
- Control flow: `if/else`, `match` → switch/case, early `return`
- Loops: `for i in 0..n`, `while cond`, `loop` with `break`/`continue`
- Stencil intrinsics: `pos.north(buf)`, `pos.south(buf)`, `pos.east(buf)`, `pos.west(buf)`, `pos.at(buf, dx, dy)`
- Stencil intrinsics (2D): `pos.north(buf)`, `pos.south(buf)`, `pos.east(buf)`, `pos.west(buf)`, `pos.at(buf, dx, dy)`
- Stencil intrinsics (3D): `pos.up(buf)`, `pos.down(buf)`, `pos.at(buf, dx, dy, dz)` for volumetric kernels
- Shared memory: `__shared__` arrays and tiles with `SharedMemoryConfig`
- Struct literals: `Point { x: 1.0, y: 2.0 }` → C compound literals
- Reference expressions: `&arr[idx]` → pointer to element with automatic `->` operator for field access
- 45+ GPU intrinsics (atomics, warp ops, sync, math)
- 120+ GPU intrinsics across 13 categories (synchronization, atomics, math, trig, hyperbolic, exponential, classification, warp, bit manipulation, memory, special, index, timing)

**Ring Kernel Features:**
- Persistent message loop with ControlBlock lifecycle management
@@ -282,7 +283,7 @@ Main crate (`ringkernel`) features:
- ringkernel-core: 65 tests
- ringkernel-cpu: 11 tests
- ringkernel-cuda: 6 GPU execution tests
- ringkernel-cuda-codegen: 143 tests (loops, shared memory, ring kernels, K2K, reference expressions)
- ringkernel-cuda-codegen: 171 tests (loops, shared memory, ring kernels, K2K, reference expressions, 120+ GPU intrinsics)
- ringkernel-wgpu-codegen: 50 tests (types, intrinsics, transpiler, validation)
- ringkernel-derive: 14 macro tests
- ringkernel-wavesim: 49 tests (including educational modes)
202 changes: 176 additions & 26 deletions crates/ringkernel-cuda-codegen/README.md
@@ -7,7 +7,7 @@ Rust-to-CUDA transpiler for RingKernel GPU kernels.
This crate enables writing GPU kernels in a restricted Rust DSL and transpiling them to CUDA C code. It supports three kernel types:

1. **Global Kernels** - Standard CUDA `__global__` functions
2. **Stencil Kernels** - Tile-based kernels with `GridPos` abstraction
2. **Stencil Kernels** - Tile-based kernels with `GridPos` abstraction (2D and 3D)
3. **Ring Kernels** - Persistent actor kernels with message loops

## Installation
@@ -39,11 +39,12 @@ let cuda_code = transpile_global_kernel(&func)?;

## Stencil Kernels

For grid-based computations with neighbor access:
For grid-based computations with neighbor access (2D and 3D):

```rust
use ringkernel_cuda_codegen::{transpile_stencil_kernel, StencilConfig};
use ringkernel_cuda_codegen::{transpile_stencil_kernel, StencilConfig, Grid};

// 2D stencil
let func: syn::ItemFn = parse_quote! {
fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p)
@@ -53,10 +54,25 @@ let func: syn::ItemFn = parse_quote! {
};

let config = StencilConfig::new("fdtd")
.with_grid(Grid::Grid2D)
.with_tile_size(16, 16)
.with_halo(1);

let cuda_code = transpile_stencil_kernel(&func, &config)?;

// 3D stencil with up/down neighbors
let func_3d: syn::ItemFn = parse_quote! {
fn laplacian_3d(p: &[f32], out: &mut [f32], pos: GridPos) {
let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p)
+ pos.up(p) + pos.down(p) - 6.0 * p[pos.idx()];
out[pos.idx()] = lap;
}
};

let config_3d = StencilConfig::new("laplacian")
.with_grid(Grid::Grid3D)
.with_tile_size(8, 8)
.with_halo(1);
```
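
To make the 3D neighbor intrinsics concrete, here is a plain-Rust CPU reference for the same 7-point Laplacian. It is only a sketch for checking transpiled output, not part of the crate. The function name `laplacian_3d_ref`, the memory layout (x fastest, then y, then z), and the choice of `y - 1` as "north" are assumptions made for illustration; the sum is symmetric, so the north/south convention does not affect the result.

```rust
/// CPU reference for a 7-point 3D Laplacian on a dense grid.
/// Hypothetical helper: layout assumed to be x fastest, then y, then z.
fn laplacian_3d_ref(p: &[f32], nx: usize, ny: usize, nz: usize) -> Vec<f32> {
    assert_eq!(p.len(), nx * ny * nz);
    let idx = |x: usize, y: usize, z: usize| (z * ny + y) * nx + x;
    let mut out = vec![0.0f32; p.len()];

    // Interior cells only; the one-cell border plays the role of a halo of 1.
    for z in 1..nz.saturating_sub(1) {
        for y in 1..ny.saturating_sub(1) {
            for x in 1..nx.saturating_sub(1) {
                let center = p[idx(x, y, z)];
                let neighbors = p[idx(x, y - 1, z)] + p[idx(x, y + 1, z)] // north, south
                    + p[idx(x + 1, y, z)] + p[idx(x - 1, y, z)]           // east, west
                    + p[idx(x, y, z + 1)] + p[idx(x, y, z - 1)];          // up, down
                out[idx(x, y, z)] = neighbors - 6.0 * center;
            }
        }
    }
    out
}
```

For the field `x² + y² + z²` sampled on an integer grid, every interior cell of the result equals 6, which makes a convenient sanity check.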

## Ring Kernels
@@ -86,35 +102,149 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
## DSL Reference

### Thread/Block Indices
- `thread_idx_x()`, `thread_idx_y()`, `thread_idx_z()`
- `block_idx_x()`, `block_idx_y()`, `block_idx_z()`
- `block_dim_x()`, `block_dim_y()`, `block_dim_z()`
- `grid_dim_x()`, `grid_dim_y()`, `grid_dim_z()`
- `thread_idx_x()`, `thread_idx_y()`, `thread_idx_z()` → `threadIdx.x/y/z`
- `block_idx_x()`, `block_idx_y()`, `block_idx_z()` → `blockIdx.x/y/z`
- `block_dim_x()`, `block_dim_y()`, `block_dim_z()` → `blockDim.x/y/z`
- `grid_dim_x()`, `grid_dim_y()`, `grid_dim_z()` → `gridDim.x/y/z`
- `warp_size()` → `warpSize`

### Stencil Intrinsics
### Stencil Intrinsics (2D)
- `pos.idx()` - Linear index
- `pos.north(buf)`, `pos.south(buf)`, `pos.east(buf)`, `pos.west(buf)`
- `pos.north(buf)`, `pos.south(buf)` - Y-axis neighbors
- `pos.east(buf)`, `pos.west(buf)` - X-axis neighbors
- `pos.at(buf, dx, dy)` - Relative offset access

### Synchronization
- `sync_threads()` - Block-level barrier
- `thread_fence()` - Device memory fence
- `thread_fence_block()` - Block memory fence

### Atomics
- `atomic_add(ptr, val)`, `atomic_sub(ptr, val)`
- `atomic_min(ptr, val)`, `atomic_max(ptr, val)`
- `atomic_exchange(ptr, val)`, `atomic_cas(ptr, compare, val)`
### Stencil Intrinsics (3D)
- `pos.up(buf)`, `pos.down(buf)` - Z-axis neighbors
- `pos.at(buf, dx, dy, dz)` - 3D relative offset access

### Math Functions
- `sqrt()`, `abs()`, `floor()`, `ceil()`, `round()`
- `sin()`, `cos()`, `tan()`, `exp()`, `log()`
- `powf()`, `min()`, `max()`, `mul_add()`
### Synchronization
- `sync_threads()` → `__syncthreads()` - Block-level barrier
- `sync_threads_count(pred)` → `__syncthreads_count()` - Count threads with predicate
- `sync_threads_and(pred)` → `__syncthreads_and()` - AND of predicate
- `sync_threads_or(pred)` → `__syncthreads_or()` - OR of predicate
- `thread_fence()` → `__threadfence()` - Device memory fence
- `thread_fence_block()` → `__threadfence_block()` - Block memory fence
- `thread_fence_system()` → `__threadfence_system()` - System memory fence
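
The three predicate barriers combine one value per thread into a single block-wide result. The sketch below models only that reduction step on the CPU, with one slice element standing in for each thread's predicate; it does not model the barrier itself. The function names are hypothetical and chosen to mirror the DSL names.

```rust
/// Models the result of `__syncthreads_count`: number of threads with a non-zero predicate.
fn sync_threads_count_model(preds: &[i32]) -> i32 {
    preds.iter().filter(|&&p| p != 0).count() as i32
}

/// Models the result of `__syncthreads_and`: non-zero only if every predicate is non-zero.
fn sync_threads_and_model(preds: &[i32]) -> i32 {
    preds.iter().all(|&p| p != 0) as i32
}

/// Models the result of `__syncthreads_or`: non-zero if any predicate is non-zero.
fn sync_threads_or_model(preds: &[i32]) -> i32 {
    preds.iter().any(|&p| p != 0) as i32
}
```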

### Atomic Operations (Integer)
- `atomic_add(ptr, val)` → `atomicAdd`
- `atomic_sub(ptr, val)` → `atomicSub`
- `atomic_min(ptr, val)` → `atomicMin`
- `atomic_max(ptr, val)` → `atomicMax`
- `atomic_exchange(ptr, val)` → `atomicExch`
- `atomic_cas(ptr, compare, val)` → `atomicCAS`
- `atomic_and(ptr, val)` → `atomicAnd`
- `atomic_or(ptr, val)` → `atomicOr`
- `atomic_xor(ptr, val)` → `atomicXor`
- `atomic_inc(ptr, val)` → `atomicInc` (increment with wrap)
- `atomic_dec(ptr, val)` → `atomicDec` (decrement with wrap)
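
The "with wrap" behavior of `atomicInc` and `atomicDec` is easy to misread, since the second argument is a limit rather than an amount. The following host-side sketch emulates the CUDA semantics with `std::sync::atomic`; the helper names `atomic_inc_wrap` and `atomic_dec_wrap` are hypothetical and are not part of the DSL.

```rust
use std::sync::atomic::{AtomicU32, Ordering};

/// Emulates CUDA `atomicInc(addr, limit)`:
/// stores 0 if old >= limit, otherwise old + 1; returns the old value.
fn atomic_inc_wrap(addr: &AtomicU32, limit: u32) -> u32 {
    addr.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |old| {
        Some(if old >= limit { 0 } else { old + 1 })
    })
    .unwrap() // the closure always returns Some, so this cannot fail
}

/// Emulates CUDA `atomicDec(addr, limit)`:
/// stores limit if old == 0 or old > limit, otherwise old - 1; returns the old value.
fn atomic_dec_wrap(addr: &AtomicU32, limit: u32) -> u32 {
    addr.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |old| {
        Some(if old == 0 || old > limit { limit } else { old - 1 })
    })
    .unwrap()
}
```

With a limit of `n - 1`, repeated increments cycle through `0..n`, which is the usual way to index a ring buffer of `n` slots.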

### Basic Math Functions
- `sqrt()`, `rsqrt()` - Square root, reciprocal sqrt
- `abs()`, `fabs()` - Absolute value
- `floor()`, `ceil()`, `round()`, `trunc()` - Rounding
- `fma()`, `mul_add()` - Fused multiply-add
- `fmin()`, `fmax()` - Minimum, maximum
- `fmod()`, `remainder()` - Modulo operations
- `copysign()` - Copy sign
- `cbrt()` - Cube root
- `hypot()` - Hypotenuse

### Trigonometric Functions
- `sin()`, `cos()`, `tan()` - Basic trig
- `asin()`, `acos()`, `atan()`, `atan2()` - Inverse trig
- `sincos()` - Combined sine and cosine
- `sinpi()`, `cospi()` - Sin/cos of π*x

### Hyperbolic Functions
- `sinh()`, `cosh()`, `tanh()` - Hyperbolic
- `asinh()`, `acosh()`, `atanh()` - Inverse hyperbolic

### Exponential and Logarithmic Functions
- `exp()`, `exp2()`, `exp10()`, `expm1()` - Exponentials
- `log()`, `ln()`, `log2()`, `log10()`, `log1p()` - Logarithms
- `pow()`, `powf()`, `powi()` - Power
- `ldexp()`, `scalbn()` - Load/scale exponent
- `ilogb()` - Extract exponent
- `erf()`, `erfc()`, `erfinv()`, `erfcinv()` - Error functions
- `lgamma()`, `tgamma()` - Gamma functions

### Classification Functions
- `is_nan()`, `isnan()` → `isnan`
- `is_infinite()`, `isinf()` → `isinf`
- `is_finite()`, `isfinite()` → `isfinite`
- `is_normal()`, `isnormal()` → `isnormal`
- `signbit()` - Check sign bit
- `nextafter()` - Next representable value
- `fdim()` - Positive difference

### Warp Operations
- `warp_shuffle(val, lane)`, `warp_shuffle_up(val, delta)`
- `warp_shuffle_down(val, delta)`, `warp_shuffle_xor(val, mask)`
- `warp_ballot(pred)`, `warp_all(pred)`, `warp_any(pred)`
- `warp_active_mask()` → `__activemask()` - Active lane mask
- `warp_shfl(mask, val, lane)` → `__shfl_sync` - Shuffle
- `warp_shfl_up(mask, val, delta)` → `__shfl_up_sync`
- `warp_shfl_down(mask, val, delta)` → `__shfl_down_sync`
- `warp_shfl_xor(mask, val, lane_mask)` → `__shfl_xor_sync`
- `warp_ballot(mask, pred)` → `__ballot_sync`
- `warp_all(mask, pred)` → `__all_sync`
- `warp_any(mask, pred)` → `__any_sync`
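
A common use of the shuffle intrinsics is a tree reduction across a warp. The sketch below simulates it on the CPU with an array of 32 values, one per lane, assuming a full participation mask. `shfl_down` and `warp_sum` are hypothetical helper names; they model what `__shfl_down_sync` does per lane rather than calling anything in this crate.

```rust
const WARP_SIZE: usize = 32;

/// Models `__shfl_down_sync` with a full mask: lane i reads lane i + delta,
/// or keeps its own value when i + delta falls outside the warp.
fn shfl_down(vals: &[i32; WARP_SIZE], delta: usize) -> [i32; WARP_SIZE] {
    let mut out = *vals;
    for lane in 0..WARP_SIZE {
        if lane + delta < WARP_SIZE {
            out[lane] = vals[lane + delta];
        }
    }
    out
}

/// Tree reduction: halve the stride each step; lane 0 ends up holding the warp-wide sum.
fn warp_sum(mut vals: [i32; WARP_SIZE]) -> i32 {
    let mut delta = WARP_SIZE / 2;
    while delta > 0 {
        let shifted = shfl_down(&vals, delta);
        for lane in 0..WARP_SIZE {
            vals[lane] += shifted[lane];
        }
        delta /= 2;
    }
    vals[0]
}
```

Only lane 0 holds the full sum at the end; the other lanes hold partial results that should be ignored.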

### Warp Match Operations (Volta+)
- `warp_match_any(mask, val)` → `__match_any_sync`
- `warp_match_all(mask, val)` → `__match_all_sync`

### Warp Reduce Operations (SM 8.0+)
- `warp_reduce_add(mask, val)` → `__reduce_add_sync`
- `warp_reduce_min(mask, val)` → `__reduce_min_sync`
- `warp_reduce_max(mask, val)` → `__reduce_max_sync`
- `warp_reduce_and(mask, val)` → `__reduce_and_sync`
- `warp_reduce_or(mask, val)` → `__reduce_or_sync`
- `warp_reduce_xor(mask, val)` → `__reduce_xor_sync`

### Bit Manipulation
- `popc()`, `popcount()`, `count_ones()` → `__popc` - Population count
- `clz()`, `leading_zeros()` → `__clz` - Count leading zeros
- `ctz()`, `trailing_zeros()` → `__ffs(x) - 1` - Count trailing zeros
- `ffs()` → `__ffs` - Find first set
- `brev()`, `reverse_bits()` → `__brev` - Bit reverse
- `byte_perm()` → `__byte_perm` - Byte permutation
- `funnel_shift_left()` → `__funnelshift_l`
- `funnel_shift_right()` → `__funnelshift_r`
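
Two of these mappings have edge cases worth checking on the host. CUDA's `__ffs` returns a 1-based bit position and 0 for a zero input, so a trailing-zero count derived from it gives -1 for zero, whereas Rust's `u32::trailing_zeros` gives 32. The sketch below emulates `__ffs` and `__funnelshift_l` in plain Rust; the function names are hypothetical and the behavior follows the CUDA definitions, not this crate's source.

```rust
/// Emulates CUDA `__ffs`: 1-based position of the least significant set bit, 0 if x == 0.
fn ffs(x: u32) -> i32 {
    if x == 0 {
        0
    } else {
        x.trailing_zeros() as i32 + 1
    }
}

/// Trailing-zero count derived from `__ffs`; note the result for x == 0 is -1.
fn ctz_via_ffs(x: u32) -> i32 {
    ffs(x) - 1
}

/// Emulates `__funnelshift_l(lo, hi, shift)`: shift the 64-bit value hi:lo left by
/// `shift & 31` bits and return the upper 32 bits.
fn funnel_shift_left(lo: u32, hi: u32, shift: u32) -> u32 {
    let wide = ((hi as u64) << 32) | lo as u64;
    ((wide << (shift & 31)) >> 32) as u32
}
```

Passing the same word as both `lo` and `hi` turns the funnel shift into a rotate, which is a handy way to test it against `u32::rotate_left`.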

### Memory Operations
- `ldg(ptr)`, `load_global(ptr)` → `__ldg` - Read-only cache load
- `prefetch_l1(ptr)` → `__prefetch_l1` - L1 prefetch
- `prefetch_l2(ptr)` → `__prefetch_l2` - L2 prefetch

### Special Functions
- `rcp()`, `recip()` → `__frcp_rn` - Fast reciprocal
- `fast_div()` → `__fdividef` - Fast division
- `saturate()`, `clamp_01()` → `__saturatef` - Saturate to [0,1]
- `j0()`, `j1()`, `jn()` - Bessel functions of first kind
- `y0()`, `y1()`, `yn()` - Bessel functions of second kind
- `normcdf()`, `normcdfinv()` - Normal CDF
- `cyl_bessel_i0()`, `cyl_bessel_i1()` - Cylindrical Bessel functions

### Clock and Timing
- `clock()` → `clock()` - 32-bit clock counter
- `clock64()` → `clock64()` - 64-bit clock counter
- `nanosleep(ns)` → `__nanosleep` - Sleep for nanoseconds

### RingContext Methods
- `ctx.thread_id()` → `threadIdx.x`
- `ctx.block_id()` → `blockIdx.x`
- `ctx.global_thread_id()` → `(blockIdx.x * blockDim.x + threadIdx.x)`
- `ctx.sync_threads()` → `__syncthreads()`
- `ctx.lane_id()` → `(threadIdx.x % 32)`
- `ctx.warp_id()` → `(threadIdx.x / 32)`
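
The expressions on the right-hand side can be checked by hand for a given launch position. This sketch evaluates them on the CPU for a 1D launch; the `Launch1D` struct and its field names are hypothetical stand-ins for `threadIdx.x`, `blockIdx.x`, and `blockDim.x`, and the constant 32 mirrors the hard-coded warp size in the mappings above.

```rust
/// Hypothetical stand-in for the CUDA built-ins of one thread in a 1D launch.
#[derive(Clone, Copy)]
struct Launch1D {
    thread_idx_x: u32,
    block_idx_x: u32,
    block_dim_x: u32,
}

impl Launch1D {
    /// Mirrors `blockIdx.x * blockDim.x + threadIdx.x`.
    fn global_thread_id(self) -> u32 {
        self.block_idx_x * self.block_dim_x + self.thread_idx_x
    }

    /// Mirrors `threadIdx.x % 32`: position within the warp.
    fn lane_id(self) -> u32 {
        self.thread_idx_x % 32
    }

    /// Mirrors `threadIdx.x / 32`: warp index within the block.
    fn warp_id(self) -> u32 {
        self.thread_idx_x / 32
    }
}
```

For thread 70 of block 2 with 128 threads per block, this gives a global id of 326, lane 6, and warp 2.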

### Ring Kernel Intrinsics
- `is_active()`, `should_terminate()`, `mark_terminated()`
- `messages_processed()`, `input_queue_size()`, `output_queue_size()`
- `input_queue_empty()`, `output_queue_empty()`, `enqueue_response(&resp)`
- `hlc_tick()`, `hlc_update(ts)`, `hlc_now()` - HLC operations
- `k2k_send(target, &msg)`, `k2k_try_recv()` - K2K messaging
- `k2k_has_message()`, `k2k_peek()`, `k2k_pending_count()`

## Type Mapping

@@ -130,13 +260,33 @@ let cuda_code = transpile_ring_kernel(&handler, &config)?;
| `&[T]` | `const T* __restrict__` |
| `&mut [T]` | `T* __restrict__` |

## Intrinsic Count

The transpiler supports **120+ GPU intrinsics** across 13 categories:

| Category | Count | Examples |
|----------|-------|----------|
| Synchronization | 7 | `sync_threads`, `thread_fence` |
| Atomics | 11 | `atomic_add`, `atomic_cas`, `atomic_and` |
| Math | 16 | `sqrt`, `fma`, `cbrt`, `hypot` |
| Trigonometric | 11 | `sin`, `asin`, `atan2`, `sincos` |
| Hyperbolic | 6 | `sinh`, `asinh` |
| Exponential | 18 | `exp`, `log2`, `erf`, `lgamma` |
| Classification | 8 | `isnan`, `isfinite`, `signbit` |
| Warp | 16 | `warp_shfl`, `warp_reduce_add`, `warp_match_any` |
| Bit Manipulation | 8 | `popc`, `clz`, `brev`, `funnel_shift_left` |
| Memory | 3 | `ldg`, `prefetch_l1` |
| Special | 13 | `rcp`, `saturate`, `normcdf` |
| Index | 13 | `thread_idx_x`, `warp_size` |
| Timing | 3 | `clock`, `clock64`, `nanosleep` |
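
As a quick arithmetic check on the "120+" figure, the per-category counts in the table can be summed directly. The constant and function names below are hypothetical; the numbers are copied from the table as given.

```rust
/// Per-category counts copied from the table above.
const INTRINSIC_COUNTS: [(&str, u32); 13] = [
    ("Synchronization", 7),
    ("Atomics", 11),
    ("Math", 16),
    ("Trigonometric", 11),
    ("Hyperbolic", 6),
    ("Exponential", 18),
    ("Classification", 8),
    ("Warp", 16),
    ("Bit Manipulation", 8),
    ("Memory", 3),
    ("Special", 13),
    ("Index", 13),
    ("Timing", 3),
];

/// Sums the Count column.
fn total_intrinsics() -> u32 {
    INTRINSIC_COUNTS.iter().map(|&(_, n)| n).sum()
}
```

The Count column sums to 133 across the 13 categories, which is consistent with the "120+" claim.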

## Testing

```bash
cargo test -p ringkernel-cuda-codegen
```

The crate includes 143 tests covering all kernel types and language features.
The crate includes 171 tests covering all kernel types, intrinsics, and language features.

## License
