A custom inference engine for Mixture of Experts (MoE) models engineered from scratch. Features handwritten OpenAI Triton kernels for fused routing and a Python-based continuous batching scheduler.
Standard PyTorch inference for MoE models (like Qwen1.5-MoE) is inefficient for two reasons:
- Memory Bottleneck: The routing mechanism uses `torch.topk` and `scatter`/`gather` operations, which trigger excessive read/write passes to GPU High-Bandwidth Memory (HBM).
- Dispatch Overhead: Serial request processing leaves the GPU idle between tokens while the CPU prepares the next step.
Goal: Build a system that fuses the routing logic into a single GPU kernel and implements continuous batching to maximize throughput on consumer hardware (12GB VRAM).
Running on NVIDIA RTX 3060 (12GB) with Qwen1.5-MoE-A2.7B:
| Metric | Standard Serial API | MoE-Triton-Serve | Improvement |
|---|---|---|---|
| Throughput | 11.87 tokens/sec | 20.38 tokens/sec | +72% (1.72x) 🚀 |
| Latency (P50) | ~84 ms/token | ~49 ms/token | -41% |
| VRAM Usage | ~7.9 GB | ~7.9 GB | (Parity) |
(Run via `python -m benchmarks.final_benchmark`)
Replaced the standard PyTorch routing block with a fused OpenAI Triton kernel.
- Fusion: Performs Logit Loading → Softmax → Top-K Selection → Normalization in a single SRAM pass.
- Result: Eliminates intermediate memory writes, significantly reducing kernel launch latency.
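As a reference for what the fused kernel computes, here is a plain-Python sketch of the routing math it performs in one pass (softmax → top-k → renormalization). The function name `route_token` and its inputs are illustrative, not the project's API:

```python
import math

def route_token(logits, top_k):
    """One token's MoE routing: the three steps the Triton kernel fuses."""
    # 1. Numerically stable softmax over the expert logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 2. Top-K expert selection by probability.
    top = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
    # 3. Renormalize the selected weights so they sum to 1.
    sel = sum(probs[i] for i in top)
    weights = [probs[i] / sel for i in top]
    return top, weights

experts, weights = route_token([0.1, 2.0, -1.0, 0.5, 1.5, 0.0], top_k=2)
```

In eager PyTorch each step materializes an intermediate tensor in HBM; the fused kernel keeps all three stages in SRAM.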
Implemented a "Cellular" batching strategy to handle dynamic requests.
- Dynamic Queue: Incoming requests are added to a running batch immediately when a slot opens (no waiting for the previous batch to finish).
- KV-Cache Management: Manages Key-Value states for active slots to prevent re-computation.
- Throughput Gain: Hides the Python CPU dispatch overhead by processing 8 requests in parallel.
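The slot-filling idea can be sketched in a few lines; the class and method names here are illustrative, not the repo's actual `engine/scheduler.py` API:

```python
from collections import deque

class Scheduler:
    """Minimal slot-based continuous batching sketch."""

    def __init__(self, max_slots=8):
        self.queue = deque()             # requests waiting for a slot
        self.slots = [None] * max_slots  # active requests; None = free slot

    def submit(self, request):
        self.queue.append(request)

    def fill_slots(self):
        # Called before every decode step: any free slot is filled
        # immediately, with no wait for the rest of the batch to finish.
        for i, slot in enumerate(self.slots):
            if slot is None and self.queue:
                self.slots[i] = self.queue.popleft()

    def finish(self, slot_index):
        # Free a slot the moment its request hits EOS / max length.
        self.slots[slot_index] = None
```

The key contrast with static batching is that `fill_slots` runs every step, so a finished request's slot is recycled on the very next decode iteration.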
- Used NF4 (NormalFloat 4-bit) quantization via `bitsandbytes`.
- Allowed a 14B-parameter model (which normally requires ~28 GB VRAM) to run comfortably on a 12 GB card while maintaining routing accuracy.
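A 4-bit load like this is typically expressed through `transformers`' `BitsAndBytesConfig`. This is a configuration sketch, not the repo's exact code; the model ID comes from the source, while the specific NF4 options shown are common defaults and an assumption here:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4-bit
    bnb_4bit_quant_type="nf4",             # NormalFloat 4-bit
    bnb_4bit_compute_dtype=torch.float16,  # dequantize to fp16 for matmuls
    bnb_4bit_use_double_quant=True,        # also quantize the quant constants
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B",
    quantization_config=bnb_config,
    device_map="auto",
)
```

At 4 bits per weight, ~14B parameters take roughly 7 GB for weights, which is consistent with the ~7.9 GB VRAM figure in the benchmark table.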
If you are reviewing the code, these are the critical files where the custom engineering logic resides:
- `kernels/router_kernel.py`
  - What it is: The custom GPU kernel written in OpenAI Triton.
  - Key Logic: Implements `_fused_router_kernel`, which fuses the Softmax, Top-K selection, and weight normalization into a single function.
  - Why it matters: This is the specific optimization that removes the memory-bandwidth bottleneck found in standard PyTorch routing.
- `engine/scheduler.py`
  - What it is: The implementation of the continuous batching logic.
  - Key Logic: A `RequestQueue` class that manages incoming prompts and a `Scheduler` that dynamically fills active batch slots.
  - Why it matters: Demonstrates how to move from static batching to dynamic, asynchronous request processing (similar to vLLM).
- `engine/moe_engine.py`
  - What it is: The main inference loop.
  - Key Logic: Manages the KV-Cache, handles the "Prefill" vs. "Decode" phases, and executes the generation loop token by token.
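The prefill/decode split can be illustrated with a toy generation loop. Everything here is a stand-in: the "cache" is a plain list instead of per-layer K/V tensors, and the "model" just predicts `max(seen) + 1` so the flow is checkable:

```python
def prefill(prompt_tokens):
    """Prefill: process the whole prompt in one pass and build the cache."""
    cache = list(prompt_tokens)      # stand-in for per-layer K/V tensors
    return max(cache) + 1, cache     # toy next-token prediction

def decode_step(token, cache):
    """Decode: process one new token, reusing and extending the cache
    instead of re-running attention over the full sequence."""
    cache.append(token)
    return max(cache) + 1, cache

def generate(prompt_tokens, n_new):
    token, cache = prefill(prompt_tokens)
    out = [token]
    for _ in range(n_new - 1):
        token, cache = decode_step(token, cache)
        out.append(token)
    return out
```

The point of the split: prefill is one large compute-bound pass, while each decode step touches only one new token, which is exactly why a cached decode loop avoids quadratic re-computation.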
- `engine/monkey_patch.py`
  - What it is: The "glue" code.
  - Key Logic: Dynamically replaces (hot-swaps) the standard `Qwen2MoeRouter` in the Hugging Face Transformers library with our custom Triton kernel at runtime.
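The hot-swap itself is the standard Python monkey-patch pattern: rebind a method on the library's class at import time. A toy illustration (`StockRouter` and `fused_forward` are stand-ins, not the real `transformers` class):

```python
class StockRouter:
    """Stand-in for the library's router module (illustrative only)."""
    def forward(self, hidden_states):
        return "pytorch-path"

def fused_forward(self, hidden_states):
    # In the real patch this would launch the fused Triton kernel.
    return "triton-path"

# Hot-swap at runtime: rebinding the class attribute redirects every
# existing and future instance, with no edits to the library source.
StockRouter.forward = fused_forward

router = StockRouter()
```

The same pattern applied to the real class lets the rest of the Transformers forward pass run unmodified while only the routing block takes the Triton path.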
- `benchmarks/final_benchmark.py`
  - What it is: The script used to generate the results table.
  - Key Logic: Runs a controlled experiment comparing a serial baseline (standard API simulation) against the batched engine.
- Model: `Qwen/Qwen1.5-MoE-A2.7B`
- Kernels: OpenAI Triton
- Engine Logic: Python, PyTorch
- Profiling: Nsight Systems, PyTorch Profiler
- Environment: WSL2 (Ubuntu 22.04)
1. Clone & Install

   ```bash
   git clone https://github.com/yourusername/moe-triton-serve.git
   cd moe-triton-serve
   pip install torch transformers triton bitsandbytes accelerate pandas
   ```

2. Run the Benchmark Comparison

   ```bash
   # Runs the serial baseline vs. the custom engine
   python -m benchmarks.final_benchmark
   ```

3. Run Correctness Tests

   ```bash
   # Verifies the Triton kernel produces the same math as PyTorch
   python tests/test_router.py
   ```
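In the same spirit as `tests/test_router.py`, a parity test compares a fused path against an unfused reference and asserts the outputs match. This standalone sketch uses toy pure-Python versions of both paths (the real test compares the Triton kernel against PyTorch); it also shows why fusion can skip work, since renormalizing over the top-k cancels the global softmax denominator:

```python
import math

def unfused_route(logits, top_k):
    """Reference path: full softmax, then top-k, then renormalize."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    probs = [e / s for e in exps]                       # full softmax output
    idx = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
    total = sum(probs[i] for i in idx)
    return idx, [probs[i] / total for i in idx]

def fused_route(logits, top_k):
    """'Fused' path: never materializes the full softmax; the global
    denominator cancels when the top-k weights are renormalized."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    idx = sorted(range(len(exps)), key=exps.__getitem__, reverse=True)[:top_k]
    total = sum(exps[i] for i in idx)
    return idx, [exps[i] / total for i in idx]
```

A parity check then asserts both paths pick the same experts and produce the same weights within floating-point tolerance.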
- The Python Bottleneck: Profiling revealed that even with a fast GPU kernel, Python's dispatch loop (CPU) is the ultimate limit on throughput for small models. A C++ scheduler (like vLLM) would be required for >5x gains.
- WSL2 Limitations: Direct GPU profiling (CUPTI) is restricted in WSL2, requiring indirect analysis via CPU-side launch metrics.
- Kernel Fusion: For bandwidth-bound operations (like MoE Routing), fusing operations into one kernel provides larger speedups than matrix math optimizations.