
abi2024/FlashMoE-Serve


MoE-Triton-Serve: High-Performance Inference Engine


A custom inference engine for Mixture of Experts (MoE) models engineered from scratch. Features handwritten OpenAI Triton kernels for fused routing and a Python-based continuous batching scheduler.

🎯 The Core Thesis (Why I built this)

Standard PyTorch inference for MoE models (like Qwen1.5-MoE) is inefficient for two reasons:

  1. Memory Bottleneck: The routing mechanism uses torch.topk and scatter/gather operations, which trigger excessive read/write passes to GPU High-Bandwidth Memory (HBM).
  2. Dispatch Overhead: Serial request processing leaves the GPU idle between tokens while the CPU prepares the next step.

Goal: Build a system that fuses the routing logic into a single GPU kernel and implements continuous batching to maximize throughput on consumer hardware (12GB VRAM).


📊 Benchmark Results

Running on NVIDIA RTX 3060 (12GB) with Qwen1.5-MoE-A2.7B:

| Metric | Standard Serial API | MoE-Triton-Serve | Improvement |
|---|---|---|---|
| Throughput | 11.87 tokens/sec | 20.38 tokens/sec | +72% (1.72x) 🚀 |
| Latency (P50) | ~84 ms/token | ~49 ms/token | -41% |
| VRAM Usage | ~7.9 GB | ~7.9 GB | (Parity) |

📸 Proof of Performance

Benchmark Console Output (Run via python -m benchmarks.final_benchmark)


🏗️ System Architecture

1. Custom Triton Router Kernel (kernels/router_kernel.py)

Replaced the standard PyTorch routing block with a fused OpenAI Triton kernel.

  • Fusion: Performs Logit Loading → Softmax → Top-K Selection → Normalization in a single SRAM pass.
  • Result: Eliminates intermediate HBM writes and collapses multiple kernel launches into one, reducing both memory traffic and launch overhead.
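The kernel itself lives in kernels/router_kernel.py; as an illustration only, the following pure-Python reference sketches the math it fuses (softmax → top-k → renormalization). The function name and shapes here are hypothetical; the real kernel computes the same per-token result on tiles held in SRAM, without materializing the intermediate probability tensor.

```python
import math

def fused_router_reference(logits, k):
    """Reference semantics for the fused router: softmax over expert
    logits, top-k expert selection, then renormalization of the
    selected weights -- one logical pass per token."""
    # Numerically stable softmax over the expert logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Pick the k highest-probability experts.
    topk = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # Renormalize the selected weights so they sum to 1.
    mass = sum(probs[i] for i in topk)
    weights = [probs[i] / mass for i in topk]
    return topk, weights

indices, weights = fused_router_reference([2.0, 1.0, 0.5, 3.0], k=2)
# indices -> [3, 0]; weights ≈ [0.731, 0.269]
```

In standard PyTorch each of these steps (softmax, topk, division) writes its result to HBM before the next reads it back; the fused kernel keeps everything on-chip.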

2. Continuous Batching Scheduler (engine/scheduler.py)

Implemented a "Cellular" batching strategy to handle dynamic requests.

  • Dynamic Queue: Incoming requests are added to a running batch immediately when a slot opens (no waiting for the previous batch to finish).
  • KV-Cache Management: Manages Key-Value states for active slots to prevent re-computation.
  • Throughput Gain: Amortizes the Python CPU dispatch overhead by processing up to 8 requests in parallel.
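The slot-filling behavior described above can be sketched as follows. This is a minimal toy (class and field names are illustrative, not the repo's engine/scheduler.py API): requests are dicts, and a freed slot is refilled from the queue on the very next step instead of waiting for the whole batch to drain.

```python
from collections import deque

class ContinuousBatcher:
    """Toy slot-based continuous batching: a waiting request joins the
    running batch the moment any slot frees up."""
    def __init__(self, num_slots=8):
        self.slots = [None] * num_slots      # active requests
        self.queue = deque()                 # waiting requests

    def submit(self, request):
        self.queue.append(request)

    def step(self):
        # Fill any free slot from the queue (no batch-boundary wait).
        for i, slot in enumerate(self.slots):
            if slot is None and self.queue:
                self.slots[i] = self.queue.popleft()
        # One decode step for every active request.
        finished = []
        for i, req in enumerate(self.slots):
            if req is None:
                continue
            req["generated"] += 1
            if req["generated"] >= req["max_tokens"]:
                finished.append(req)
                self.slots[i] = None         # slot reopens immediately
        return finished

batcher = ContinuousBatcher(num_slots=2)
for n in (1, 3, 2):
    batcher.submit({"generated": 0, "max_tokens": n})
done = []
while len(done) < 3:
    done.extend(batcher.step())
# The 1-token request finishes first and its slot is refilled mid-flight,
# so the third request never waits for the other two to complete.
```

Static batching would run the third request only after the first two finished; here it starts on step 2, which is exactly where the throughput gain comes from.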

3. Aggressive Quantization

  • Used NF4 (NormalFloat 4-bit) quantization via bitsandbytes.
  • Allowed a 14B parameter model (which normally requires ~28GB VRAM) to run comfortably on a 12GB card while maintaining routing accuracy.
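The VRAM claim follows from simple per-weight arithmetic. A back-of-envelope sketch (weights only; the observed ~7.9 GB additionally includes quantization constants, the KV-cache, and activations):

```python
# Back-of-envelope VRAM footprint for a ~14B-parameter model.
params = 14e9

fp16_gb = params * 2 / 1e9     # FP16: 2 bytes per weight
nf4_gb  = params * 0.5 / 1e9   # NF4: 4 bits per weight

print(f"FP16 weights: ~{fp16_gb:.0f} GB")   # ~28 GB -- exceeds a 12 GB card
print(f"NF4 weights:  ~{nf4_gb:.0f} GB")    # ~7 GB -- fits with headroom
```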

📂 Key Code Breakdown

If you are reviewing the code, these are the critical files where the custom engineering logic resides:

1. The Core Innovation (Triton Kernel)

  • kernels/router_kernel.py
    • What it is: The custom GPU kernel written in OpenAI Triton.
    • Key Logic: Implements _fused_router_kernel which fuses the Softmax, Top-K selection, and weight normalization into a single function.
    • Why it matters: This is the specific optimization that removes the memory bandwidth bottleneck found in standard PyTorch routing.

2. The Engine (Scheduler & Dispatch)

  • engine/scheduler.py

    • What it is: The implementation of Continuous Batching logic.
    • Key Logic: A RequestQueue class that manages incoming prompts and a Scheduler that dynamically fills active batch slots.
    • Why it matters: Demonstrates how to move away from static batching to dynamic, asynchronous request processing (similar to vLLM).
  • engine/moe_engine.py

    • What it is: The main inference loop.
    • Key Logic: Manages the KV-Cache, handles the "Prefill" vs "Decode" phases, and executes the generation loop token-by-token.

3. The Integration

  • engine/monkey_patch.py
    • What it is: The "glue" code.
    • Key Logic: Dynamically replaces (hot-swaps) the standard Qwen2MoeRouter in the HuggingFace Transformers library with our custom Triton kernel at runtime.
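The hot-swap pattern itself is straightforward to sketch. The classes below are toy stand-ins (the real patch replaces the Qwen2Moe routing module inside HuggingFace Transformers): walk the model, find every instance of the stock router class, and assign a drop-in replacement so the unmodified forward pass now calls the custom kernel.

```python
class StockRouter:                       # stand-in for the HF routing module
    def route(self, x):
        return "pytorch path"

class TritonRouter:                      # stand-in for the fused-kernel router
    def route(self, x):
        return "triton path"

class Block:
    def __init__(self):
        self.router = StockRouter()

class Model:
    def __init__(self, n_layers=2):
        self.layers = [Block() for _ in range(n_layers)]

def hot_swap_routers(model):
    """Replace every stock router instance in place; the surrounding
    model code keeps calling .route() and never notices the swap."""
    swapped = 0
    for layer in model.layers:
        if isinstance(layer.router, StockRouter):
            layer.router = TritonRouter()
            swapped += 1
    return swapped

model = Model()
hot_swap_routers(model)
# Every layer now routes through the Triton path without touching
# any other model code -- no fork of the upstream library needed.
```

The appeal of this approach is isolation: the optimization lives entirely in one patch function and can be toggled off by simply not calling it.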

4. Verification

  • benchmarks/final_benchmark.py
    • What it is: The script used to generate the results table.
    • Key Logic: Runs a controlled experiment comparing a serial baseline (Standard API simulation) vs. the Batched Engine.

🛠️ Tech Stack

  • Model: Qwen/Qwen1.5-MoE-A2.7B
  • Kernels: OpenAI Triton
  • Engine Logic: Python, PyTorch
  • Profiling: Nsight Systems, PyTorch Profiler
  • Environment: WSL2 (Ubuntu 22.04)

💻 How to Run

  1. Clone & Install

    git clone https://github.com/yourusername/moe-triton-serve.git
    cd moe-triton-serve
    pip install torch transformers triton bitsandbytes accelerate pandas
  2. Run the Benchmark Comparison

    # Runs the serial baseline vs. the custom engine
    python -m benchmarks.final_benchmark
  3. Run Correctness Tests

    # Verifies the Triton kernel produces the same math as PyTorch
    python tests/test_router.py

🧠 Engineering Learnings

  • The Python Bottleneck: Profiling revealed that even with a fast GPU kernel, Python's dispatch loop (CPU) is the ultimate limit on throughput for small models. A C++ scheduler (like vLLM) would be required for >5x gains.
  • WSL2 Limitations: Direct GPU profiling (CUPTI) is restricted in WSL2, requiring indirect analysis via CPU-side launch metrics.
  • Kernel Fusion: For bandwidth-bound operations (like MoE Routing), fusing operations into one kernel provides larger speedups than matrix math optimizations.

