feat(model): Add LLM benchmarks code (#808)
- Add LLM benchmarks code
- Add LLM metrics
Aries-ckt authored Nov 20, 2023
2 parents 29ca172 + 0fff10b commit ecc5d5d
Showing 14 changed files with 1,012 additions and 15 deletions.
214 changes: 214 additions & 0 deletions docker/examples/benchmarks/benchmarks_llm_11k_prompt.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pilot/configs/model_config.py
@@ -2,6 +2,7 @@
# -*- coding:utf-8 -*-

import os
from functools import cache

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
@@ -22,6 +23,7 @@
os.chdir(new_directory)


@cache
def get_device() -> str:
try:
import torch
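
The only functional change in this file is the @cache decorator on get_device(): the device probe (importing torch and checking the available backend) now runs once per process, and every later call returns the memoized result. A minimal sketch of the same pattern, using a hypothetical probe_device() stand-in because the real function body is elided in this hunk:

import time
from functools import cache


@cache
def probe_device() -> str:
    # Hypothetical stand-in for get_device(); pretend the probe is expensive.
    time.sleep(0.1)
    return "cpu"


probe_device()  # first call runs the probe
probe_device()  # later calls return the cached "cpu" without re-probing
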
78 changes: 78 additions & 0 deletions pilot/model/base.py
@@ -3,7 +3,9 @@

from enum import Enum
from typing import TypedDict, Optional, Dict, List, Any

from dataclasses import dataclass, asdict
import time
from datetime import datetime
from pilot.utils.parameter_utils import ParameterDescription

@@ -47,13 +49,89 @@ class WorkerApplyType(str, Enum):
UPDATE_PARAMS = "update_params"


@dataclass
class ModelInferenceMetrics:
"""A class to represent metrics for assessing the inference performance of a LLM."""

start_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the model inference starts."""

end_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the model inference ends."""

current_time_ms: Optional[int] = None
"""The current timestamp (in milliseconds) when the model inference return partially output(stream)."""

first_token_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the first token is generated."""

first_completion_time_ms: Optional[int] = None
"""The timestamp (in milliseconds) when the first completion is generated."""

first_completion_tokens: Optional[int] = None
"""The number of tokens when the first completion is generated."""

prompt_tokens: Optional[int] = None
"""The number of tokens in the input prompt."""

completion_tokens: Optional[int] = None
"""The number of tokens in the generated completion."""

total_tokens: Optional[int] = None
"""The total number of tokens (prompt plus completion)."""

speed_per_second: Optional[float] = None
"""The average number of tokens generated per second."""

@staticmethod
def create_metrics(
last_metrics: Optional["ModelInferenceMetrics"] = None,
) -> "ModelInferenceMetrics":
start_time_ms = last_metrics.start_time_ms if last_metrics else None
first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None
first_completion_time_ms = (
last_metrics.first_completion_time_ms if last_metrics else None
)
first_completion_tokens = (
last_metrics.first_completion_tokens if last_metrics else None
)
prompt_tokens = last_metrics.prompt_tokens if last_metrics else None
completion_tokens = last_metrics.completion_tokens if last_metrics else None
total_tokens = last_metrics.total_tokens if last_metrics else None
speed_per_second = last_metrics.speed_per_second if last_metrics else None

if not start_time_ms:
start_time_ms = time.time_ns() // 1_000_000
current_time_ms = time.time_ns() // 1_000_000
end_time_ms = current_time_ms

return ModelInferenceMetrics(
start_time_ms=start_time_ms,
end_time_ms=end_time_ms,
current_time_ms=current_time_ms,
first_token_time_ms=first_token_time_ms,
first_completion_time_ms=first_completion_time_ms,
first_completion_tokens=first_completion_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
speed_per_second=speed_per_second,
)

def to_dict(self) -> Dict:
return asdict(self)


@dataclass
class ModelOutput:
text: str
error_code: int
model_context: Dict = None
finish_reason: str = None
usage: Dict[str, Any] = None
metrics: Optional[ModelInferenceMetrics] = None
"""Some metrics for model inference"""

def to_dict(self) -> Dict:
return asdict(self)
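
A short usage sketch of the new ModelInferenceMetrics.create_metrics helper as it behaves during streaming: values carried in the previous chunk's metrics (notably start_time_ms and the token counters) are preserved, while current_time_ms and end_time_ms are refreshed on every call. The loop and sleep below are illustrative only; the import assumes the DB-GPT sources are on the Python path.

import time

from pilot.model.base import ModelInferenceMetrics

# First chunk: no previous metrics, so start_time_ms is stamped now.
metrics = ModelInferenceMetrics.create_metrics()
start = metrics.start_time_ms

# Simulate two later streamed chunks.
for _ in range(2):
    time.sleep(0.05)
    metrics = ModelInferenceMetrics.create_metrics(metrics)

assert metrics.start_time_ms == start      # carried forward unchanged
assert metrics.current_time_ms >= start    # refreshed on each call
print(metrics.to_dict())
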
2 changes: 2 additions & 0 deletions pilot/model/cluster/base.py
@@ -21,6 +21,8 @@ class PromptRequest(BaseModel):
context_len: int = None
echo: bool = True
span_id: str = None
metrics: bool = False
"""Whether to return metrics of inference"""


class EmbeddingsRequest(BaseModel):
105 changes: 98 additions & 7 deletions pilot/model/cluster/worker/default_worker.py
@@ -1,10 +1,13 @@
import os
import logging

from typing import Dict, Iterator, List, Optional
import time
import traceback

from pilot.configs.model_config import get_device
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
from pilot.model.base import ModelOutput
from pilot.model.base import ModelOutput, ModelInferenceMetrics
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
@@ -144,14 +147,29 @@ def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
)

previous_response = ""
last_metrics = ModelInferenceMetrics.create_metrics()
is_first_generate = True

context_len = params.get("context_len") or self.context_len
for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), context_len
):
model_output, incremental_output, output_str = self._handle_output(
output, previous_response, model_context
(
model_output,
incremental_output,
output_str,
current_metrics,
) = self._handle_output(
output,
previous_response,
model_context,
last_metrics,
is_first_generate,
)
if is_first_generate:
is_first_generate = False
previous_response = output_str
last_metrics = current_metrics
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
@@ -191,13 +209,28 @@ async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
previous_response = ""
context_len = params.get("context_len") or self.context_len

last_metrics = ModelInferenceMetrics.create_metrics()
is_first_generate = True
async for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), context_len
):
model_output, incremental_output, output_str = self._handle_output(
output, previous_response, model_context
(
model_output,
incremental_output,
output_str,
current_metrics,
) = self._handle_output(
output,
previous_response,
model_context,
last_metrics,
is_first_generate,
)
if is_first_generate:
is_first_generate = False

previous_response = output_str
last_metrics = current_metrics
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
@@ -262,7 +295,14 @@ def _prepare_generate_stream(self, params: Dict, span_operation_name: str):

return params, model_context, generate_stream_func, model_span

def _handle_output(self, output, previous_response, model_context):
def _handle_output(
self,
output,
previous_response,
model_context,
last_metrics: ModelInferenceMetrics,
is_first_generate: bool,
):
finish_reason = None
usage = None
if isinstance(output, dict):
@@ -273,14 +313,17 @@ def _handle_output(self, output, previous_response, model_context):
logger.info(f"finish_reason: {finish_reason}")
incremental_output = output[len(previous_response) :]
print(incremental_output, end="", flush=True)

metrics = _new_metrics_from_model_output(last_metrics, is_first_generate, usage)
model_output = ModelOutput(
text=output,
error_code=0,
model_context=model_context,
finish_reason=finish_reason,
usage=usage,
metrics=metrics,
)
return model_output, incremental_output, output
return model_output, incremental_output, output, metrics

def _handle_exception(self, e):
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
@@ -289,6 +332,8 @@
text="**GPU OutOfMemory, Please Refresh.**", error_code=1
)
else:
msg = traceback.format_exc()
logger.error(f"Model inference error, detail: {msg}")
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=1,
@@ -310,3 +355,49 @@ def _parse_model_max_length(model, tokenizer) -> Optional[int]:
return model_config.max_position_embeddings
except Exception:
return None


def _new_metrics_from_model_output(
last_metric: ModelInferenceMetrics,
is_first_generate: bool,
usage: Optional[Dict] = None,
) -> ModelInferenceMetrics:
metrics = ModelInferenceMetrics.create_metrics(last_metric)
if is_first_generate:
logger.info(f"is_first_generate, usage: {usage}")
metrics.first_completion_time_ms = time.time_ns() // 1_000_000

if not usage or not isinstance(usage, dict):
return metrics
prompt_tokens = usage.get("prompt_tokens")
completion_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")

if prompt_tokens is None:
prompt_tokens = metrics.prompt_tokens
if completion_tokens is None:
completion_tokens = metrics.completion_tokens
if total_tokens is None:
total_tokens = metrics.total_tokens

if is_first_generate and (completion_tokens is not None):
# completion_tokens == 0 means we are still in the prefill phase
metrics.first_completion_tokens = completion_tokens
if completion_tokens == 1:
metrics.first_token_time_ms = metrics.first_completion_time_ms

if prompt_tokens:
metrics.prompt_tokens = prompt_tokens
if completion_tokens:
metrics.completion_tokens = completion_tokens
if total_tokens:
metrics.total_tokens = total_tokens
elif prompt_tokens and completion_tokens:
total_tokens = prompt_tokens + completion_tokens
metrics.total_tokens = total_tokens

if total_tokens:
# elapsed time (seconds) since inference started
duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
metrics.speed_per_second = total_tokens / duration
return metrics
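
To make the metrics flow concrete, here is a simplified, self-contained sketch of how the worker threads metrics through a token stream: each chunk's usage counters are folded into a fresh ModelInferenceMetrics built from the previous one, and speed_per_second is total_tokens divided by the elapsed seconds, mirroring _new_metrics_from_model_output. The _fold helper and the fake usage dicts are illustrative only, not part of the worker code.

import time

from pilot.model.base import ModelInferenceMetrics


def _fold(last: ModelInferenceMetrics, usage: dict) -> ModelInferenceMetrics:
    # Illustrative reduction of _new_metrics_from_model_output: copy the
    # usage counters and recompute the decode speed from elapsed time.
    metrics = ModelInferenceMetrics.create_metrics(last)
    metrics.prompt_tokens = usage.get("prompt_tokens", metrics.prompt_tokens)
    metrics.completion_tokens = usage.get("completion_tokens", metrics.completion_tokens)
    metrics.total_tokens = usage.get("total_tokens", metrics.total_tokens)
    if metrics.total_tokens:
        duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
        metrics.speed_per_second = metrics.total_tokens / duration
    return metrics


last = ModelInferenceMetrics.create_metrics()
for n in (1, 2, 3):  # pretend three streamed chunks arrive
    time.sleep(0.05)
    usage = {"prompt_tokens": 12, "completion_tokens": n, "total_tokens": 12 + n}
    last = _fold(last, usage)
print(last.speed_per_second)
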
7 changes: 6 additions & 1 deletion pilot/model/cluster/worker/manager.py
@@ -1000,11 +1000,16 @@ def run_worker_manager(
embedding_model_name: str = None,
embedding_model_path: str = None,
start_listener: Callable[["WorkerManager"], None] = None,
**kwargs,
):
global worker_manager

worker_params: ModelWorkerParameters = _parse_worker_params(
model_name=model_name, model_path=model_path, standalone=standalone, port=port
model_name=model_name,
model_path=model_path,
standalone=standalone,
port=port,
**kwargs,
)

setup_logging(
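
The change here only forwards arbitrary keyword arguments from run_worker_manager to _parse_worker_params, so callers can set additional ModelWorkerParameters fields without the wrapper having to enumerate them. A hedged call sketch; the model name and path are placeholders, and model_type is a hypothetical extra field used only to show the forwarding:

from pilot.model.cluster.worker.manager import run_worker_manager

# Extra keyword arguments now flow through **kwargs into _parse_worker_params.
run_worker_manager(
    model_name="vicuna-13b-v1.5",
    model_path="/data/models/vicuna-13b-v1.5",  # placeholder path
    port=8000,
    model_type="huggingface",  # hypothetical extra field, forwarded via **kwargs
)
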
45 changes: 42 additions & 3 deletions pilot/model/llm_out/vllm_llm.py
@@ -1,9 +1,13 @@
from typing import Dict
import os
from vllm import AsyncLLMEngine
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams


_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"


async def generate_stream(
model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
):
@@ -37,20 +41,55 @@ async def generate_stream(
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0
gen_params = {
"stop": list(stop),
"ignore_eos": False,
}
prompt_token_ids = None
if _IS_BENCHMARK:
gen_params["stop"] = []
gen_params["ignore_eos"] = True
prompt_len = context_len - max_new_tokens - 2
prompt_token_ids = tokenizer([prompt]).input_ids[0]
prompt_token_ids = prompt_token_ids[-prompt_len:]
sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=False,
stop=list(stop),
max_tokens=max_new_tokens,
**gen_params
)

results_generator = model.generate(
prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
)
results_generator = model.generate(prompt, sampling_params, request_id)
async for request_output in results_generator:
prompt = request_output.prompt
if echo:
text_outputs = [prompt + output.text for output in request_output.outputs]
else:
text_outputs = [output.text for output in request_output.outputs]
text_outputs = " ".join(text_outputs)
yield {"text": text_outputs, "error_code": 0, "usage": {}}

# Note: usage is not supported yet
prompt_tokens = len(request_output.prompt_token_ids)
completion_tokens = sum(
len(output.token_ids) for output in request_output.outputs
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
finish_reason = (
request_output.outputs[0].finish_reason
if len(request_output.outputs) == 1
else [output.finish_reason for output in request_output.outputs]
)
yield {
"text": text_outputs,
"error_code": 0,
"usage": usage,
"finish_reason": finish_reason,
}
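
The benchmark branch above empties the stop list, sets ignore_eos, and clips the prompt so that exactly max_new_tokens can always be generated within context_len; note that _IS_BENCHMARK is read once at import time, so DB_GPT_MODEL_BENCHMARK must be exported before pilot.model.llm_out.vllm_llm is imported. A small self-contained sketch of the truncation arithmetic, with a toy token list standing in for tokenizer([prompt]).input_ids[0]:

# Benchmark-mode prompt clipping: keep only the newest tokens so that
# prompt + max_new_tokens (+ 2 tokens of slack) fits in the context window.
context_len = 4096
max_new_tokens = 2048
prompt_len = context_len - max_new_tokens - 2      # 2046 tokens of prompt budget

prompt_token_ids = list(range(11_000))             # stand-in for the ~11k-token benchmark prompt
prompt_token_ids = prompt_token_ids[-prompt_len:]  # keep the tail, drop the oldest tokens

assert len(prompt_token_ids) == prompt_len
print(len(prompt_token_ids))  # 2046
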