Sharded integration tests #995

Draft · wants to merge 11 commits into base: main
43 changes: 43 additions & 0 deletions .github/workflows/pkgci_shark_ai.yml
@@ -66,6 +66,49 @@ jobs:
name: smoke-test-${{ matrix.name }}
path: smoke-test-${{ matrix.name }}.xml

sharded_tests:
name: "Sharded smoke tests (${{ matrix.name }})"
runs-on: ${{ matrix.runs-on }}
strategy:
fail-fast: false
matrix:
include:
- name: amdgpu_mi300_gfx942_tp4
runs-on: linux-mi300-4gpu-ossci-nod-ai
test_device: gfx942_tp4
python-version: 3.11
defaults:
run:
shell: bash
env:
PACKAGE_DOWNLOAD_DIR: ${{ github.workspace }}/.packages
VENV_DIR: ${{ github.workspace }}/.venv
steps:
- name: Run rocminfo
if: contains(matrix.test_device, 'gfx')
run: rocminfo
- name: "Checkout Code"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: "Set up environment and install PkgCI Artifacts"
uses: ./.github/actions/pkgci-setup
with:
python-version: ${{matrix.python-version}}
artifact-run-id: ${{ inputs.artifact_run_id }}
- name: Run LLM Sharded Smoke Test
run: |
source ${VENV_DIR}/bin/activate
pytest -v --test_device=${{ matrix.test_device }} \
--junitxml=sharded-smoke-test-${{ matrix.name }}.xml \
app_tests/integration_tests/llm/shortfin/sharded_model_test.py \
--log-cli-level=INFO
- name: Upload Test Results
if: always()
uses: actions/upload-artifact@v4
with:
name: sharded-smoke-test-${{ matrix.name }}
path: sharded-smoke-test-${{ matrix.name }}.xml


integration_test:
name: "Integration Test (${{ matrix.name }})"
runs-on: ${{ matrix.runs-on }}
27 changes: 27 additions & 0 deletions app_tests/integration_tests/llm/device_settings.py
@@ -24,6 +24,32 @@ class DeviceSettings:
server_flags=("--device=hip",),
)

GFX942_TP4 = DeviceSettings(
compile_flags=(
"--iree-hip-target=gfx942",
"--iree-hal-target-device=hip[0]",
"--iree-hal-target-device=hip[1]",
"--iree-hal-target-device=hip[2]",
"--iree-hal-target-device=hip[3]",
),
server_flags=(
"--device=hip",
"--device_ids",
"0",
"1",
"2",
"3",
),
# server_flags=(
# "--device=hip",
# "--device_ids",
# "0",
# "0",
# "0",
# "0",
# ), # temporarily testing with all 4 device IDs pointing to the same physical device
)

GFX90A = DeviceSettings(
compile_flags=(
"--iree-hal-target-backends=rocm",
@@ -34,6 +60,7 @@ class DeviceSettings:

table = {
"gfx942": GFX942,
"gfx942_tp4": GFX942_TP4,
"gfx90a": GFX90A,
"host": CPU,
"hostcpu": CPU,
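Note: as a rough illustration of what the new TP4 settings ask of the compiler, the compile_flags above expand into an invocation along these lines. This is a sketch only: the compiler entry point and file names are assumptions, since compile_model in model_management.py assembles the actual command.

# sketch: tool name and paths are assumed; flags copied verbatim from GFX942_TP4
iree-compile model.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-device=hip[0] \
  --iree-hal-target-device=hip[1] \
  --iree-hal-target-device=hip[2] \
  --iree-hal-target-device=hip[3] \
  -o model.vmfb

The matching server_flags then expose the same four GPUs to the server process via --device=hip and --device_ids 0 1 2 3.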
75 changes: 74 additions & 1 deletion app_tests/integration_tests/llm/model_management.py
@@ -95,6 +95,9 @@ class ModelConfig:
repo_id: Optional[str] = None
local_path: Optional[Path] = None
azure_config: Optional[AzureConfig] = None
tensor_parallelism_size: Optional[
int
] = None # Number of shards for tensor parallelism

def __post_init__(self):
if self.source == ModelSource.HUGGINGFACE_FROM_GGUF:
@@ -117,12 +120,13 @@ def __post_init__(self):
class ModelArtifacts:
"""Container for all paths related to model artifacts."""

weights_path: Path
weights_path: Path # Main weights file (unranked for sharded models)
tokenizer_path: Path
mlir_path: Path
vmfb_path: Path
config_path: Path
model_config: ModelConfig # config that was originally used to generate these artifacts
shard_paths: Optional[list[Path]] = None # Paths to sharded weight files (rank0-N)


class ModelStageManager:
@@ -305,6 +309,42 @@ def prepare_tokenizer(self) -> Path:

return tokenizer_path

def shard_model(self, weights_path: Path) -> Tuple[Path, Optional[list[Path]]]:
"""Shards model using tensor parallelism if configured."""
if not self.config.tensor_parallelism_size:
return weights_path, None

logger.info(
f"Sharding model with tensor parallelism size {self.config.tensor_parallelism_size}"
)

# Determine output paths
base_name = weights_path.stem
output_base = self.model_dir / f"{base_name}.sharded"
output_irpa = output_base.with_suffix(".irpa")

# Run sharding script
subprocess.run(
[
"python",
"-m",
"sharktank.examples.sharding.shard_llm_dataset",
f"--{weights_path.suffix.strip('.')}-file={weights_path}",
f"--output-irpa={output_irpa}",
f"--tensor-parallelism-size={self.config.tensor_parallelism_size}",
],
check=True,
)

# Collect paths to all shards
shard_paths = [
output_base.with_suffix(f".rank{i}.irpa")
for i in range(self.config.tensor_parallelism_size)
]

logger.info(f"Model successfully sharded into {len(shard_paths)} shards")
return output_irpa, shard_paths

def export_model(self, weights_path: Path) -> Tuple[Path, Path]:
"""Exports model to MLIR format."""
bs_string = ",".join(map(str, self.config.batch_sizes))
@@ -318,6 +358,10 @@ def export_model(self, weights_path: Path) -> Tuple[Path, Path]:
f" Batch Sizes: {bs_string}"
)

# For sharded models, we use the unranked irpa file
if self.config.tensor_parallelism_size:
weights_path = weights_path.with_suffix(".irpa")

subprocess.run(
[
"python",
@@ -346,6 +390,7 @@ def compile_model(self, mlir_path: Path) -> Path:
"-o",
str(vmfb_path),
]

compile_command.extend(self.config.device_settings.compile_flags)

subprocess.run(compile_command, check=True)
@@ -377,6 +422,11 @@ def process_model(self, config: ModelConfig) -> ModelArtifacts:

tokenizer_path = manager.prepare_tokenizer()

# Stage 1.5: Shard model if tensor parallelism is configured
shard_paths = None
if config.tensor_parallelism_size:
weights_path, shard_paths = manager.shard_model(weights_path)

# Stage 2: Export model (fresh every time)
mlir_path, config_path = manager.export_model(weights_path)

@@ -390,6 +440,7 @@ def process_model(self, config: ModelConfig) -> ModelArtifacts:
vmfb_path=vmfb_path,
config_path=config_path,
model_config=config,
shard_paths=shard_paths,
)


@@ -452,3 +503,25 @@ def process_model(self, config: ModelConfig) -> ModelArtifacts:
batch_sizes=(1, 4),
device_settings=None,
)

# Sharded version of tinystories for smoke testing
TEST_MODELS["tinystories_tp4"] = ModelConfig(
source=ModelSource.HUGGINGFACE_FROM_SAFETENSORS,
dataset_name="Mxode/TinyStories-LLaMA2-25M-256h-4l-GQA",
model_file="model.irpa", # This will be the final converted file name
tokenizer_id="Mxode/TinyStories-LLaMA2-25M-256h-4l-GQA",
batch_sizes=(4,), # Fixed batch size of 4 for testing
device_settings=None,
tensor_parallelism_size=4, # 4-way sharding for smoke testing
)

# Example of a sharded model configuration
TEST_MODELS["llama3.1_405b"] = ModelConfig(
source=ModelSource.HUGGINGFACE_FROM_GGUF,
repo_id="meta-llama/Llama-3.1-405B", # Note: This is a placeholder, actual repo may differ
model_file="llama3.1-405b.f16.gguf",
tokenizer_id="meta-llama/Llama-3.1-405B",
batch_sizes=(1, 4),
device_settings=None,
tensor_parallelism_size=8, # Required for 405B model
)
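Note: for the tinystories_tp4 configuration above, shard_model builds a command roughly equivalent to the following. This is a sketch; the paths are placeholders (in practice they come from model_dir), and the flags mirror the subprocess.run call in shard_model.

# sketch: placeholder paths; 4-way sharding as configured for tinystories_tp4
python -m sharktank.examples.sharding.shard_llm_dataset \
  --irpa-file=/tmp/model_dir/model.irpa \
  --output-irpa=/tmp/model_dir/model.sharded.irpa \
  --tensor-parallelism-size=4

shard_model then records model.rank0.irpa through model.rank3.irpa as shard_paths and returns the unranked model.sharded.irpa; export_model exports from the unranked file, and the server passes the unranked file plus every rank file to --parameters.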
16 changes: 13 additions & 3 deletions app_tests/integration_tests/llm/server_management.py
@@ -11,6 +11,10 @@
from .device_settings import DeviceSettings
from .model_management import ModelArtifacts

from logging import getLogger

logger = getLogger(__name__)


@dataclass
class ServerConfig:
@@ -72,16 +76,20 @@ def start(self) -> None:
f"--tokenizer_json={self.config.artifacts.tokenizer_path}",
f"--model_config={self.config.artifacts.config_path}",
f"--vmfb={self.config.artifacts.vmfb_path}",
f"--parameters={self.config.artifacts.weights_path}",
f"--parameters",
str(self.config.artifacts.weights_path),
*(str(path) for path in (self.config.artifacts.shard_paths or [])),
f"--port={self.port}",
f"--prefix_sharing_algorithm={self.config.prefix_sharing_algorithm}",
]
cmd.extend(self.config.device_settings.server_flags)

logger.info("Starting server with command: %s", " ".join(cmd))

self.process = subprocess.Popen(cmd)
self.wait_for_ready()

def wait_for_ready(self, timeout: int = 30) -> None:
def wait_for_ready(self, timeout: int = 180) -> None:
"""Waits for server to be ready and responding to health checks."""
if self.port is None:
raise RuntimeError("Server hasn't been started")
@@ -91,7 +99,9 @@ def wait_for_ready(self, timeout: int = 30) -> None:
try:
requests.get(f"http://localhost:{self.port}/health")
return
except requests.exceptions.ConnectionError:
except requests.exceptions.ConnectionError as e:
logger.info("While attempting to server,")
logger.info("Encountered connection error %s", e)
time.sleep(1)
raise TimeoutError(f"Server failed to start within {timeout} seconds")

2 changes: 2 additions & 0 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -51,6 +51,8 @@ def model_artifacts(tmp_path_factory, request, test_device):
mlir_path=model_dir / "model.mlir",
vmfb_path=model_dir / "model.vmfb",
config_path=model_dir / "config.json",
model_config=model_config,
shard_paths=None,
)

# Process model and create artifacts
102 changes: 102 additions & 0 deletions app_tests/integration_tests/llm/shortfin/sharded_model_test.py
@@ -0,0 +1,102 @@
"""
Test for sharded model serving with concurrent requests.
Tests a 4-way sharded model with batch size 4 handling 3 concurrent requests.
"""

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
import logging
import pytest
import requests
import uuid

from ..model_management import AccuracyValidationException

logger = logging.getLogger(__name__)

pytestmark = pytest.mark.parametrize(
"model_artifacts,server",
[
["tinystories_tp4", {"prefix_sharing": "none"}],
],
indirect=True,
)

# Test prompt and expected response pattern for tinystories
PROMPT = "Once upon a time, there was a"
EXPECTED_PATTERN = "little" # Common word in children's stories


class TestShardedModelServer:
"""Test suite for sharded model server functionality."""

def test_concurrent_generation_sharded(self, server: tuple[Any, int]) -> None:
"""Tests concurrent text generation with a sharded model.

Uses 3 concurrent requests to test the server's ability to handle
multiple requests with a 4-way sharded model configured for batch size 4.

Args:
server: Tuple of (process, port) from server fixture
"""
process, port = server
assert process.poll() is None, "Server process terminated unexpectedly"

concurrent_requests = 3 # Fixed number of concurrent requests

with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
futures = [
executor.submit(self._generate, PROMPT, port)
for _ in range(concurrent_requests)
]

for future in as_completed(futures):
response = future.result()
if EXPECTED_PATTERN not in response:
raise AccuracyValidationException(
expected=f"...{EXPECTED_PATTERN}...",
actual=response,
message=f"Generation did not contain expected pattern.\nExpected to contain: {EXPECTED_PATTERN}\nActual response: {response}",
)

def _generate(self, prompt: str, port: int) -> str:
"""Helper method to make generation request to server.

Args:
prompt: Input text prompt
port: Server port number

Returns:
Generated text response

Raises:
requests.exceptions.RequestException: If request fails
AccuracyValidationException: If response format is invalid
"""
payload = {
"text": prompt,
"sampling_params": {
"max_completion_tokens": 15,
"temperature": 0.0, # Use greedy sampling for deterministic output
},
"rid": uuid.uuid4().hex,
"stream": False,
}

response = requests.post(
f"http://localhost:{port}/generate",
headers={"Content-Type": "application/json"},
json=payload,
timeout=30,
)
response.raise_for_status()

data = response.text
if not data.startswith("data: "):
raise AccuracyValidationException(
expected="Response starting with 'data: '",
actual=data,
message=f"Invalid response format.\nExpected format starting with 'data: '\nActual response: {data}",
)

return data[6:].rstrip("\n")
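
Note: to reproduce this smoke test outside CI, the invocation from the pkgci_shark_ai.yml workflow above can be run locally against a checkout with the release packages installed into a venv. This is a sketch; the venv path and JUnit file name are placeholders.

# sketch: mirrors the "Run LLM Sharded Smoke Test" workflow step
source .venv/bin/activate
pytest -v --test_device=gfx942_tp4 \
  --junitxml=sharded-smoke-test-local.xml \
  app_tests/integration_tests/llm/shortfin/sharded_model_test.py \
  --log-cli-level=INFO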