Stop token follow up for staging engine. Properly report stop reason (#78)

* wip

* clean

* black

* minor
masahi authored Nov 22, 2023
1 parent 2eb6317 commit 28422a1
Showing 4 changed files with 102 additions and 59 deletions.
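The change wires stop-sequence hits from the engine through a new StopRequestCommand so the worker can report FinishReason.Stop instead of FinishReason.Cancelled. Below is a minimal, self-contained sketch of that command flow; it reuses names from the diff (StopRequestCommand, CancelRequestCommand, FinishReason, step) but the queue/worker wiring is simplified and hypothetical, not the actual mlc_serve implementation.

    # Sketch only: stopped and cancelled requests travel through the same command
    # queue but are collected separately, so step() can report distinct finish reasons.
    from dataclasses import dataclass
    from enum import Enum
    from queue import Queue
    from typing import List, Union

    RequestId = str

    class FinishReason(Enum):
        Stop = "stop"
        Cancelled = "cancelled"

    @dataclass
    class CancelRequestCommand:
        request_id: RequestId

    @dataclass
    class StopRequestCommand:
        request_id: RequestId

    @dataclass
    class SequenceGenerationOutput:
        request_id: RequestId
        finish_reason: FinishReason

    class GenerationLoopWorkerSketch:
        """Tracks stopped vs. cancelled requests so each gets its own finish reason."""

        def __init__(self) -> None:
            self.stopped_requests: List[RequestId] = []
            self.cancelled_requests: List[RequestId] = []

        def handle_command(self, cmd: Union[StopRequestCommand, CancelRequestCommand]) -> None:
            if isinstance(cmd, StopRequestCommand):
                self.stopped_requests.append(cmd.request_id)
            elif isinstance(cmd, CancelRequestCommand):
                self.cancelled_requests.append(cmd.request_id)

        def step(self) -> List[SequenceGenerationOutput]:
            outputs = [SequenceGenerationOutput(r, FinishReason.Stop) for r in self.stopped_requests]
            outputs += [SequenceGenerationOutput(r, FinishReason.Cancelled) for r in self.cancelled_requests]
            self.stopped_requests.clear()
            self.cancelled_requests.clear()
            return outputs

    command_queue: Queue = Queue()
    command_queue.put(StopRequestCommand("req-0"))    # a stop sequence matched in the engine
    command_queue.put(CancelRequestCommand("req-1"))  # the client cancelled this one

    worker = GenerationLoopWorkerSketch()
    while not command_queue.empty():
        worker.handle_command(command_queue.get())

    for out in worker.step():
        print(out.request_id, out.finish_reason.name)  # req-0 Stop, req-1 Cancelled

The real worker, shown in the staging_engine_worker.py diff below, additionally drains a request queue, batches decoding, and attaches validation errors to cancelled requests.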
24 changes: 16 additions & 8 deletions serve/mlc_serve/engine/staging_engine.py
@@ -15,15 +15,17 @@
RequestState,
ScopedInferenceEngine,
SequenceOutput,
check_stopping_sequences
check_stopping_sequences,
)
from .model_module import ModelModule, TokenizerModule
from .staging_engine_worker import (
AddRequestsCommand,
CancelRequestCommand,
StopRequestCommand,
ShutdownCommand,
run_generation_loop_worker,
)

logger = logging.getLogger(__name__)


@@ -90,7 +92,9 @@ def add(self, requests: list[Request]):
# wrap the stop sequence with list if necessary
if req.stopping_criteria.stop_sequences:
if isinstance(req.stopping_criteria.stop_sequences, str):
req.stopping_criteria.stop_sequences = [req.stopping_criteria.stop_sequences]
req.stopping_criteria.stop_sequences = [
req.stopping_criteria.stop_sequences
]
assert isinstance(req.stopping_criteria.stop_sequences, list)

# If the request violates the tokenization, this returns None, so skip.
@@ -107,6 +111,11 @@ def cancel(self, request_id: RequestId):
raise RuntimeError("GenerationLoopWorker process is not running")
self.command_queue.put(CancelRequestCommand(request_id))

def stop_request(self, request_id: RequestId):
if not self._is_ready_to_serve():
raise RuntimeError("GenerationLoopWorker process is not running")
self.command_queue.put(StopRequestCommand(request_id))

def has_pending_requests(self) -> bool:
with self.requests_lock:
return len(self.requests) > 0
@@ -173,13 +182,12 @@ def step(self) -> InferenceStepResult:
delta = self._decode_last_output(state)
state.output_text += delta

state.output_text, delta, state.is_ended = check_stopping_sequences(state.stopping_criteria,
state.output_text,
delta,
state.is_ended)
# signal workers to stop generation
state.output_text, delta, state.is_ended = check_stopping_sequences(
state.stopping_criteria, state.output_text, delta, state.is_ended
)
# signal workers to stop generation
if state.is_ended:
self.cancel(state.request_id)
self.stop_request(state.request_id)

outputs.append(
RequestOutput(
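For context on the staging_engine.py hunk above: check_stopping_sequences takes (stopping_criteria, output_text, delta, is_ended) and returns the possibly truncated text, the trimmed delta, and the end flag. The sketch below only illustrates that kind of check under the assumption of plain substring matching; it is not the actual mlc_serve helper.

    from typing import List, Optional, Tuple

    def check_stopping_sequences_sketch(
        stop_sequences: Optional[List[str]],
        output_text: str,
        delta: str,
        is_ended: bool,
    ) -> Tuple[str, str, bool]:
        """Illustrative only: end generation at the first matched stop sequence."""
        for stop in stop_sequences or []:
            idx = output_text.find(stop)
            if idx == -1:
                continue
            # Characters generated past the end of the stop string.
            overshoot = len(output_text) - (idx + len(stop))
            if overshoot:
                delta = delta[:-overshoot]                 # trim the latest delta
                output_text = output_text[: idx + len(stop)]
            is_ended = True
            break
        return output_text, delta, is_ended

    # The stop string "</s>" appears mid-delta, so both text and delta are trimmed.
    text, delta, ended = check_stopping_sequences_sketch(["</s>"], "Hello</s>!!", "</s>!!", False)
    assert (text, delta, ended) == ("Hello</s>", "</s>", True)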
78 changes: 58 additions & 20 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -15,6 +15,7 @@

logger = logging.getLogger(__name__)


@dataclass
class ShutdownCommand:
pass
@@ -30,6 +31,11 @@ class CancelRequestCommand:
request_id: RequestId


@dataclass
class StopRequestCommand:
request_id: RequestId


GenerationLoopWorkerCommand = Union[
ShutdownCommand, AddRequestsCommand, CancelRequestCommand
]
@@ -61,9 +67,12 @@ def __init__(
self.max_context_length = self.model_artifact_config.max_context_length
self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
self.max_decode_steps = min(
self.cache_manager.get_kv_cache_size(), model_module.engine_config.max_decode_steps
self.cache_manager.get_kv_cache_size(),
model_module.engine_config.max_decode_steps,
)
self.min_decode_steps = min(
self.max_decode_steps - 1, model_module.engine_config.min_decode_steps
)
self.min_decode_steps = min(self.max_decode_steps - 1, model_module.engine_config.min_decode_steps)
self.prompt_allocate_ratio = model_module.engine_config.prompt_allocate_ratio
assert self.prompt_allocate_ratio >= 1.0

@@ -72,6 +81,7 @@ def __init__(
self.has_new_requests = Condition(lock=self.queue_lock)

self.cancelled_requests = list[RequestState]()
self.stopped_requests = list[RequestState]()

self.current_batch = dict[RequestId, RequestState]()

Expand All @@ -81,28 +91,40 @@ def add(self, request_states: list[RequestState]):
# cancel them instead.
valid_states = []
for request_state in request_states:
if request_state.validation_err is not None or request_state.prompt_len >= self.max_context_length:
if (
request_state.validation_err is not None
or request_state.prompt_len >= self.max_context_length
):
self.cancelled_requests.append(request_state)
else:
valid_states.append(request_state)

self.queue.extend(valid_states)
self.has_new_requests.notify_all()

def cancel(self, request_id: RequestId):
with self.queue_lock:
queue_index_to_delete = None
for i, state in enumerate(self.queue):
if state.request_id == request_id:
queue_index_to_delete = i
self.cancelled_requests.append(state)
break
def _get_request_state(self, request_id: RequestId) -> Optional[RequestState]:
for state in self.queue:
if state.request_id == request_id:
return state

if queue_index_to_delete is not None:
del self.queue[queue_index_to_delete]
return None

def _cancel_or_stop_request(
self, request_id: RequestId, requests: list[RequestState]
):
with self.queue_lock:
state = self._get_request_state(request_id)
if state:
# the queued request is removed here and recorded so step() can report its finish reason
self.queue.remove(state)
requests.append(state)

if request_id in self.current_batch:
self.cancelled_requests.append(self.current_batch[request_id])
requests.append(self.current_batch[request_id])

def cancel_request(self, request_id: RequestId):
self._cancel_or_stop_request(request_id, self.cancelled_requests)

def stop_request(self, request_id: RequestId):
self._cancel_or_stop_request(request_id, self.stopped_requests)

def wait_for_request(self, timeout_seconds=None) -> bool:
with self.queue_lock:
@@ -138,6 +160,20 @@ def step(self) -> GenerationLoopWorkerOutput:
)
self._remove_request_from_batch(state.request_id)

for state in self.stopped_requests:
outputs.append(
SequenceGenerationOutput(
# TODO: support multi-sequence
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Stop,
)
)
if state.request_id in self.current_batch:
self._remove_request_from_batch(state.request_id)

self.stopped_requests.clear()

for state in self.cancelled_requests:
err = None
if state.validation_err:
@@ -149,7 +185,7 @@ def step(self) -> GenerationLoopWorkerOutput:
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Cancelled,
error = err
error=err,
)
)
if state.request_id in self.current_batch:
@@ -307,13 +343,13 @@ def _has_request_to_process(self) -> bool:
return self.queue or self.current_batch

def _should_stop_by_length(self, state: RequestState) -> bool:
# TODO: currently, we simply return true for both stopping reasons.
# in the future, we can differentiate these two.
# TODO: currently, we simply return true for both stopping reasons.
# in the future, we can differentiate these two.
# this include prompt tokens and gen tokens so far
num_context_tokens = len(state.token_ids)
num_context_tokens = len(state.token_ids)
if num_context_tokens >= self.model_artifact_config.max_context_length:
return True
num_gen_tokens = num_context_tokens - state.prompt_len
num_gen_tokens = num_context_tokens - state.prompt_len
if num_gen_tokens >= state.stopping_criteria.max_tokens:
return True
return False
@@ -366,7 +402,9 @@ def handle_command():
elif isinstance(cmd, AddRequestsCommand):
worker.add(cmd.request_states)
elif isinstance(cmd, CancelRequestCommand):
worker.cancel(cmd.request_id)
worker.cancel_request(cmd.request_id)
elif isinstance(cmd, StopRequestCommand):
worker.stop_request(cmd.request_id)
else:
logger.error("Unknown command type %s", type(cmd))
break
3 changes: 1 addition & 2 deletions serve/tests/test_engine.py
@@ -34,7 +34,7 @@ def test(args: argparse.Namespace):
engine_config = get_engine_config({
"use_staging_engine": args.use_staging_engine,
"max_num_sequences": args.max_num_sequences,
"max_input_len": args.max_input_len,
"max_input_len": args.max_input_len,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"prompt_allocate_ratio": args.prompt_allocate_ratio
@@ -119,7 +119,6 @@ def test(args: argparse.Namespace):
parser = argparse.ArgumentParser()
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="dist")
parser.add_argument("--num-shards", type=int, default=1)
parser.add_argument("--max-input-len", type=int, default=512)
parser.add_argument("--max-num-sequences", type=int, default=8)
parser.add_argument("--max-output-len", type=int, default=20)
56 changes: 27 additions & 29 deletions serve/tests/unittest/test_engine_with_samplers.py
@@ -20,16 +20,16 @@
from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule

def create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,

):
engine_config = get_engine_config({
"use_staging_engine": use_staging_engine,
"max_num_sequences": max_num_sequences,
"max_input_len": max_input_len,
"max_num_sequences": max_num_sequences,
"max_input_len": max_input_len,
# Use defaults for "min_decode_steps", "max_decode_steps", "prompt_allocate_ratio"
})

@@ -57,27 +57,27 @@ def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos):
messages = [ChatMessage(role="user", content=prompt)],
sampling_params = SamplingParams(
temperature=0.0,
),
),
stopping_criteria = StoppingCriteria(
max_tokens=max_tokens,
max_tokens=max_tokens,
stop_sequences=stop
),
),
debug_options = DebugOptions(ignore_eos = ignore_eos)
)

def test_max_tokens(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
ignore_eos=False
):
prompt = "Write a merge sort program in Python."
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,
)

@@ -91,7 +91,7 @@ def test_max_tokens(
for res in results.outputs:
assert len(res.sequences) == 1
seq = res.sequences[0]

if seq.is_finished:
assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
assert seq.finish_reason == FinishReason.Length
@@ -103,17 +103,17 @@


def test_ignore_eos(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
):
prompt = "hi"
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_sequences,
model_artifact_path,
use_staging_engine,
max_num_sequences,
max_input_len,
)
s = 113
@@ -141,7 +141,7 @@ def test_ignore_eos(
def test_stop(
model_artifact_path,
use_staging_engine,
max_num_sequences=4,
max_num_sequences=4,
max_input_len=512,
num_requests=5,
):
@@ -167,12 +167,10 @@ def test_stop(
seq = res.sequences[0]
req_id = int(res.request_id)
if seq.is_finished:
# TODO: Currently staging engine returns FinishReason.Cancelled.
# This needs to be fixed.
#assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}"
assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}"
assert not seq.delta
gen_txt = generated[req_id]

# stop token should appear only once in the gen text.
found = sum([gen_txt.count(str_stop) for str_stop in requests[req_id].stopping_criteria.stop_sequences])
assert found == 1, f"{gen_txt!r}, matches: {found}"
@@ -186,10 +184,10 @@ def test_stop(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="../../../dist")
parser.add_argument("--artifact-path", type=str, default="dist")
args = parser.parse_args()
model_artifact_path = os.path.join(args.artifact_path, args.local_id)

test_max_tokens(model_artifact_path, use_staging_engine=True)
test_max_tokens(model_artifact_path, use_staging_engine=False)
test_ignore_eos(model_artifact_path, use_staging_engine=True)
