Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1100,8 +1100,8 @@ jobs:
run: node scripts/generate-bundle-report.cjs frontend/dist

- name: Upload bundle report
uses: actions/upload-artifact@v7
if: always() && needs.changes.outputs.frontend == 'true'
uses: actions/upload-artifact@v7
with:
name: bundle-report
path: frontend/dist/bundle-report.json
Expand Down
89 changes: 0 additions & 89 deletions ai-engine/utils/self_hosted_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
Phase 3: SGLang vs vLLM benchmark (PortKit prompt shapes)

Issue: #1203 - Self-hosted LLM inference deployment
Issue: #1320 - Enforce Q5_K_M minimum quantization for production inference
"""

import asyncio
import logging
import os
import re
import time
from dataclasses import dataclass, field
from enum import Enum
Expand All @@ -22,71 +20,6 @@
logger = logging.getLogger(__name__)


# Lowest bit depth accepted for the GGUF branch of check_quantization_floor.
MIN_QUANT_BITS_GGUF = 5
# Lowest bit depth accepted for the AWQ/EXL2/GPTQ branch of check_quantization_floor.
MIN_QUANT_BITS_AWQ = 4
# NOTE: despite the MIN_ prefix, check_quantization_floor applies this as an
# upper bound -- a group size greater than this value fails the check.
MIN_AWQ_GROUP_SIZE = 128

# GGUF quantization tags. Not referenced by the other code visible in this file.
# NOTE(review): presumably ordered from lowest to highest precision -- confirm
# before relying on the ordering.
QUANT_BIT_ORDER = ["Q2_K", "Q3_K", "Q4_K", "Q4_0", "Q5_K", "Q5_K_M", "Q6_K", "Q8_0"]


def _parse_quant_bits(model_name: str) -> Optional[int]:
"""Extract quantization bit depth from a model filename or identifier."""
pattern = re.compile(r"Q([0-9]+)_?K?|Q([0-9]+)\.")
for match in pattern.finditer(model_name):
bits = match.group(1) or match.group(2)
if bits:
try:
return int(bits)
except ValueError:
pass
return None


def check_quantization_floor(
    model_name: str,
    quant_type: str = "gguf",
    awq_group_size: Optional[int] = None,
) -> tuple[bool, str]:
    """Report whether a model satisfies the minimum quantization floor.

    Rules applied:
      * ``gguf`` / ``llama``: bit depth must be at least ``MIN_QUANT_BITS_GGUF``.
      * ``awq`` / ``exl2`` / ``gptq``: bit depth must be at least
        ``MIN_QUANT_BITS_AWQ`` and, when a group size is supplied, it must not
        exceed ``MIN_AWQ_GROUP_SIZE``.
      * Any other quantization type is not checked and passes.

    A model whose bit depth cannot be parsed from its name also passes.

    Args:
        model_name: Model filename or identifier to inspect.
        quant_type: Quantization family name; compared case-sensitively.
        awq_group_size: Group size for the AWQ-style branch, if known.

    Returns:
        A ``(passes, detail)`` tuple where ``detail`` is a human-readable
        explanation of the outcome.
    """
    is_gguf = quant_type in ("gguf", "llama")
    is_awq_family = quant_type in ("awq", "exl2", "gptq")

    # Unknown families are skipped before any parsing is attempted.
    if not (is_gguf or is_awq_family):
        return True, "quantization type unrecognized, skipping check"

    depth = _parse_quant_bits(model_name)

    if is_gguf:
        if depth is None:
            return True, "quantization bit depth unknown (GGUF)"
        if depth >= MIN_QUANT_BITS_GGUF:
            return True, f"GGUF {depth}-bit (meets Q5_K_M floor)"
        return False, (
            f"model is {depth}-bit; Q5_K_M (5-bit) is the minimum floor for GGUF. "
            f"Models below Q5_K_M produce syntax errors in code generation."
        )

    # AWQ / EXL2 / GPTQ family from here on.
    if depth is None:
        return True, "quantization bit depth unknown (AWQ/EXL2)"
    if depth < MIN_QUANT_BITS_AWQ:
        return False, (
            f"model is {depth}-bit; AWQ/EXL2 requires 4-bit minimum. "
            f"Use AWQ 4-bit with group_size ≤ {MIN_AWQ_GROUP_SIZE}."
        )

    group_too_large = (
        awq_group_size is not None and awq_group_size > MIN_AWQ_GROUP_SIZE
    )
    if group_too_large:
        return False, (
            f"AWQ group_size={awq_group_size} exceeds maximum {MIN_AWQ_GROUP_SIZE}. "
            f"For reliable code generation, use group_size ≤ {MIN_AWQ_GROUP_SIZE}."
        )

    # A missing (or zero) group size is reported as 'default'.
    group_label = awq_group_size or 'default'
    return True, f"AWQ/EXL2 {depth}-bit group_size={group_label} (meets floor)"


class InferenceProvider(str, Enum):
"""Supported inference providers"""

Expand Down Expand Up @@ -131,10 +64,6 @@ class InferenceConfig:
# vLLM specific
vllm_url: Optional[str] = None

# Quantization metadata (used for floor validation)
model_quant_type: str = "gguf"
awq_group_size: Optional[int] = None

# Performance tuning
max_tokens: int = 4096
temperature: float = 0.1
Expand All @@ -149,14 +78,6 @@ class InferenceConfig:
warmup_requests: int = 1
keep_alive: int = 300 # seconds

def validate_quantization(self) -> tuple[bool, str]:
    """Run the quantization floor check against this configuration.

    Returns:
        The ``(passes, detail)`` tuple produced by ``check_quantization_floor``
        for this config's ``model_name``, ``model_quant_type`` and
        ``awq_group_size``.
    """
    name = self.model_name
    quant_kind = self.model_quant_type
    group_size = self.awq_group_size
    outcome = check_quantization_floor(
        name,
        quant_type=quant_kind,
        awq_group_size=group_size,
    )
    return outcome


@dataclass
class InferenceResult:
Expand Down Expand Up @@ -221,8 +142,6 @@ def _load_config_from_env(self) -> InferenceConfig:
runpod_api_key=os.getenv("RUNPOD_API_KEY"),
sglang_url=os.getenv("SGLANG_URL"),
vllm_url=os.getenv("VLLM_URL"),
model_quant_type=os.getenv("MODEL_QUANT_TYPE", "gguf").lower(),
awq_group_size=int(os.getenv("AWQ_GROUP_SIZE", "128")),
max_tokens=int(os.getenv("MAX_TOKENS", "4096")),
temperature=float(os.getenv("LLM_TEMPERATURE", "0.1")),
timeout=int(os.getenv("INFERENCE_TIMEOUT", "120")),
Expand All @@ -231,14 +150,6 @@ def _load_config_from_env(self) -> InferenceConfig:

def _initialize_client(self):
"""Initialize the appropriate HTTP client based on provider"""
passes, detail = self.config.validate_quantization()
if not passes:
logger.warning(
f"QUANTIZATION FLOOR WARNING for model '{self.config.model_name}': {detail}"
)
else:
logger.info(f"Quantization check for '{self.config.model_name}': {detail}")

if self.config.endpoint_url:
try:
from openai import OpenAI
Expand Down
73 changes: 3 additions & 70 deletions ai_engine/mmsd/TRAINING_REPORT.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,74 +145,7 @@ python3 ai_engine/mmsd/train_portkit_coder.py

---

## 4. Training Recipe (Catastrophic Forgetting Mitigation)

Fine-tuning exclusively on MMSD domain-specific pairs risks **catastrophic forgetting**: the model overwrites general Java/JS knowledge with Minecraft-specific patterns. The fix is a **general programming data mix** (12% of training tokens).

### Why 12%?

- At `r=64` (QLoRA rank), many weights are updated → high risk of forgetting
- 5–15% is the standard range cited in fine-tuning literature
- 12% preserves general reasoning while allowing MMSD specialization

### Mixing Procedure

```python
from datasets import load_dataset, concatenate_datasets

# 1. Load MMSD (validated_pairs.jsonl)
mmsd = load_dataset("json", data_files="validated_pairs.jsonl")["train"] # 1,400 pairs

# 2. Load general code dataset — filter to Java + JavaScript
general = load_dataset("m-a-p/CodeFeedback-Filtered-Instruction", split="train")
general_java_js = general.filter(lambda x: x["lang"] in ["java", "javascript"])

# 3. Sample ~200 general pairs, shuffle deterministically
general_sample = general_java_js.shuffle(seed=42).select(range(200))

# 4. Format general examples to match Stage A prompt template
# (system prompt + user instruction + assistant code response)

# 5. Mix to achieve ~12% general / ~88% MMSD by token count
mixed = concatenate_datasets([mmsd_formatted, general_formatted])
mixed_token_ratio = min(general_tokens / (mmsd_tokens + general_tokens), 0.12)

# 6. Shuffle and split 90/10
mixed = mixed.shuffle(seed=42)
```

### General Code Dataset

| Property | Value |
|----------|-------|
| Dataset | `m-a-p/CodeFeedback-Filtered-Instruction` |
| Languages | Java, JavaScript |
| Sample size | ~200 instruction pairs |
| Prompt template | General code assistant (not PortKit-specific) |
| Caching | `/tmp/portkit_general_code/general_code_sample.jsonl` |

### Expected Effects

| Metric | Without Mix | With Mix (12%) |
|--------|------------|----------------|
| General Java/JS tasks | Degraded | ≤ 2% regression |
| MMSD task quality | Baseline | Improved consistency |
| Edge cases (abstract classes, generics, lambdas) | May degrade | Better handling |

### Verification

To evaluate the effect of the mix on general code tasks:
```bash
python3 ai_engine/mmsd/evaluate.py \
--model alexchapin/portkit-coder-7b-merged \
--baseline Qwen/Qwen2.5-Coder-7B-Instruct \
--eval-data ai_engine/mmsd/data/processed/validated_pairs.jsonl \
--output evaluation_results.json
```

---

## 5. Evaluation
## 4. Evaluation

### Evaluation Script
```bash
Expand All @@ -238,7 +171,7 @@ python3 ai_engine/mmsd/evaluate.py \

---

## 6. Hugging Face Hub Repositories
## 5. Hugging Face Hub Repositories

| Repository | Description | URL |
|------------|-------------|-----|
Expand All @@ -249,7 +182,7 @@ Both repos are set to **private** visibility.

---

## 7. Pipeline Verification
## 6. Pipeline Verification

The training pipeline was verified end-to-end using `Qwen/Qwen2.5-Coder-0.5B` on CPU:

Expand Down
3 changes: 2 additions & 1 deletion ai_engine/mmsd/premium_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
"""

import os
import json
import re
import time
import logging
from typing import Optional
from dataclasses import dataclass
from dataclasses import dataclass, field

import httpx

Expand Down
Loading
Loading