diff --git a/autotest/interface/restful/test_restful_chat_func.py b/autotest/interface/restful/test_restful_chat_func.py index b27243959..0aa5d7b77 100644 --- a/autotest/interface/restful/test_restful_chat_func.py +++ b/autotest/interface/restful/test_restful_chat_func.py @@ -1,11 +1,8 @@ -import random from concurrent.futures import ThreadPoolExecutor -from random import randint import pytest from tqdm import tqdm from utils.restful_return_check import (assert_chat_completions_batch_return, assert_chat_completions_stream_return, - assert_chat_interactive_batch_return, assert_chat_interactive_stream_return, get_repeat_times) from lmdeploy.serve.openai.api_client import APIClient, get_model_list @@ -66,13 +63,6 @@ def process_one(question): msg = [dict(role='user', content=question)] - data = api_client.chat_interactive_v1(msg, - session_id=randint(1, 100), - repetition_penalty=1.02, - request_output_len=224) - for item in data: - pass - data = api_client.chat_completions_v1(model=model_name, messages=msg, repetition_penalty=1.02, @@ -88,14 +78,6 @@ def process_one(question): for response in tqdm(executor.map(process_one, ['你是谁'] * 500)): continue - def test_issue1324_illegal_topk(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself', top_k=-1): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The top_k `-1` cannot be a negative integer.' - assert output.get('object') == 'error' - @pytest.mark.order(8) @pytest.mark.turbomind @@ -614,455 +596,3 @@ def test_logprobs_streaming(self): length = api_client.encode(response, add_bos=False)[1] assert outputList[-1].get('choices')[0].get('finish_reason') == 'length' assert length == 5 or length == 6 - - -@pytest.mark.order(8) -@pytest.mark.turbomind -@pytest.mark.pytorch -@pytest.mark.flaky(reruns=2) -class TestRestfulInterfaceChatInteractive: - - def test_return_info_with_prompt(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself', temperature=0.01): - continue - assert_chat_interactive_batch_return(output) - - def test_return_info_with_messegae(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt=[{ - 'role': 'user', - 'content': 'Hi, pls intro yourself' - }], - temperature=0.01): - continue - assert_chat_interactive_batch_return(output) - - def test_return_info_with_prompt_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself', stream=True, temperature=0.01): - outputList.append(output) - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 1) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - - def test_return_info_with_messegae_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt=[{ - 'role': 'user', - 'content': 'Hi, pls intro yourself' - }], - stream=True, - temperature=0.01): - outputList.append(output) - - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 1) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - - def test_single_stopword(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Shanghai is', stop=' is', temperature=0.01): 
- continue - assert_chat_interactive_batch_return(output) - assert ' is' not in output.get('text') - assert output.get('finish_reason') == 'stop' - - def test_single_stopword_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Shanghai is', stop=' is', stream=True, temperature=0.01): - outputList.append(output) - - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - assert ' to' not in outputList[index].get('text') - assert output.get('finish_reason') == 'stop' - - def test_array_stopwords(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Shanghai is', stop=[' is', '上海', ' to'], temperature=0.01): - continue - assert_chat_interactive_batch_return(output) - assert ' is' not in output.get('text') - assert ' 上海' not in output.get('text') - assert ' to' not in output.get('text') - assert output.get('finish_reason') == 'stop' - - def test_array_stopwords_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Shanghai is', - stop=[' is', '上海', ' to'], - stream=True, - temperature=0.01): - outputList.append(output) - - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - assert ' is' not in outputList[index].get('text') - assert '上海' not in outputList[index].get('text') - assert ' to' not in outputList[index].get('text') - assert output.get('finish_reason') == 'stop' - - def test_special_words(self): - message = '<|im_start|>system\n当开启工具以及代码时,根据需求选择合适的工具进行调用\n' + \ - '<|im_end|><|im_start|>system name=<|interpreter|>\n你现在已经' + \ - '能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。当你向 python ' + \ - '发送含有 Python >代码的消息时,它将在该环境中执行。这个工具适用于多种场景,' + \ - '如数据分析或处理(包括数据操作、统计分析、图表绘制),复杂的计算问题(解决数学和物理' + \ - '难题),编程示例(理解编程概念或特性),文本处理和分析(比如文本解析和自然语言处理),机器学习和数据科学(用于' + \ - '展示模型训练和数据可视化),以及文件操作和数据导入(处理CSV、JSON等格式的文件)。<|im_end|>\n' + \ - '<|im_start|>user\n设 $L$ 为圆周$x^2+y^2=2x$,计算曲线积分:$I=\\int_L' + \ - '{x\\mathrm{d}s}=$<|im_end|>\n<|im_start|>assistant' - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt=message, skip_special_tokens=False, temperature=0.01): - continue - assert_chat_interactive_batch_return(output) - assert '<|action_start|><|interpreter|>' in output.get('text') - - for output in api_client.chat_interactive_v1(prompt=message, skip_special_tokens=True, temperature=0.01): - continue - assert_chat_interactive_batch_return(output) - assert '<|action_start|><|interpreter|>' not in output.get('text') - - def test_minimum_repetition_penalty(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Shanghai is', - repetition_penalty=0.1, - temperature=0.01, - request_output_len=512): - continue - assert_chat_interactive_batch_return(output) - assert get_repeat_times(output.get('text'), 'is a name') > 5 or get_repeat_times( - output.get('text'), 'Shanghai is') > 5 - - def test_minimum_repetition_penalty_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Shanghai is', - repetition_penalty=0.1, - temperature=0.01, - stream=True, - request_output_len=512): - 
outputList.append(output) - - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - response = '' - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - response += outputList[index].get('text') - assert get_repeat_times(response, 'is a name') > 5 or get_repeat_times(response, 'Shanghai is') > 5 - - def test_repetition_penalty_bigger_than_1(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Shanghai is', - repetition_penalty=1.2, - temperature=0.01, - request_output_len=512): - continue - assert_chat_interactive_batch_return(output) - - def test_repetition_penalty_bigger_than_1_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Shanghai is', - repetition_penalty=1.2, - stream=True, - temperature=0.01, - request_output_len=512): - outputList.append(output) - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - - def test_multiple_rounds(self): - api_client = APIClient(BASE_URL) - history = 0 - session_id = random.randint(0, 100000) - for i in range(3): - for output in api_client.chat_interactive_v1(prompt='Shanghai is', - temperature=0.01, - interactive_mode=True, - session_id=session_id): - continue - assert_chat_interactive_batch_return(output) - assert output.get('history_tokens') == history - history += output.get('input_tokens') + output.get('tokens') - - def test_multiple_rounds_streaming(self): - api_client = APIClient(BASE_URL) - history = 0 - session_id = random.randint(0, 100000) - for i in range(3): - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself', - stream=True, - temperature=0.01, - interactive_mode=True, - session_id=session_id): - outputList.append(output) - print(outputList) - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - assert outputList[-1].get('history_tokens') == history - history += outputList[-1].get('input_tokens') + outputList[-1].get('tokens') - - def test_minimum_topp(self): - api_client = APIClient(BASE_URL) - outputList = [] - for i in range(3): - for output in api_client.chat_interactive_v1(prompt='Shanghai is', top_p=0.01, request_output_len=10): - continue - assert_chat_interactive_batch_return(output) - outputList.append(output) - assert outputList[0] == outputList[1] - assert outputList[1] == outputList[2] - - def test_minimum_topp_streaming(self): - api_client = APIClient(BASE_URL) - model_name = api_client.available_models[0] - responseList = [] - for i in range(3): - outputList = [] - response = '' - for output in api_client.chat_interactive_v1(model=model_name, - prompt='Hi, pls intro yourself', - stream=True, - top_p=0.01, - request_output_len=10): - outputList.append(output) - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - response += outputList[index].get('text') - responseList.append(response) - assert responseList[0] == responseList[1] or responseList[1] == responseList[2] - - def test_minimum_topk(self): - api_client = 
APIClient(BASE_URL) - outputList = [] - for i in range(3): - for output in api_client.chat_interactive_v1(prompt='Shanghai is', top_k=1, request_output_len=10): - continue - assert_chat_interactive_batch_return(output) - outputList.append(output) - assert outputList[0] == outputList[1] - assert outputList[1] == outputList[2] - - def test_minimum_topk_streaming(self): - api_client = APIClient(BASE_URL) - model_name = api_client.available_models[0] - responseList = [] - for i in range(3): - outputList = [] - response = '' - for output in api_client.chat_interactive_v1(model=model_name, - prompt='Hi, pls intro yourself', - stream=True, - top_k=1, - request_output_len=10): - outputList.append(output) - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - response += outputList[index].get('text') - responseList.append(response) - assert responseList[0] == responseList[1] - assert responseList[1] == responseList[2] - - def test_mutilple_times_response_should_not_same(self): - api_client = APIClient(BASE_URL) - outputList = [] - for i in range(3): - for output in api_client.chat_interactive_v1(prompt='Shanghai is', request_output_len=100): - continue - assert_chat_interactive_batch_return(output) - outputList.append(output) - assert outputList[0] != outputList[1] or outputList[1] != outputList[2] - - def test_mutilple_times_response_should_not_same_streaming(self): - api_client = APIClient(BASE_URL) - model_name = api_client.available_models[0] - responseList = [] - for i in range(3): - outputList = [] - response = '' - for output in api_client.chat_interactive_v1(model=model_name, - prompt='Hi, pls intro yourself', - stream=True, - request_output_len=100): - outputList.append(output) - assert_chat_interactive_stream_return(outputList[-1], True) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - response += outputList[index].get('text') - responseList.append(response) - assert responseList[0] != responseList[1] or responseList[1] != responseList[2] - - def test_longtext_input(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself' * 100000, temperature=0.01): - continue - assert output.get('finish_reason') == 'length' - assert output.get('text') == '' - - def test_longtext_input_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself' * 100000, - stream=True, - temperature=0.01): - outputList.append(output) - assert outputList[0].get('finish_reason') == 'length', outputList - assert outputList[0].get('text') == '' - assert len(outputList) == 1 - - def test_ignore_eos(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Hi, what is your name?', - ignore_eos=True, - request_output_len=100, - temperature=0.01): - continue - assert_chat_interactive_batch_return(output) - assert output.get('tokens') == 100 or output.get('tokens') == 101 - assert output.get('finish_reason') == 'length' - - def test_ignore_eos_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Hi, what is your name?', - ignore_eos=True, - stream=True, - request_output_len=100, - temperature=0.01): - outputList.append(output) - 
assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - assert output.get('finish_reason') == 'length' - assert outputList[-1].get('tokens') == 100 or outputList[-1].get('tokens') == 101 - - def test_max_tokens(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself', - request_output_len=5, - temperature=0.01): - continue - assert_chat_interactive_batch_return(output) - assert output.get('finish_reason') == 'length' - assert output.get('tokens') == 5 or output.get('tokens') == 6 - - def test_max_tokens_streaming(self): - api_client = APIClient(BASE_URL) - outputList = [] - for output in api_client.chat_interactive_v1(prompt='Hi, pls intro yourself', - stream=True, - request_output_len=5, - temperature=0.01): - outputList.append(output) - assert_chat_interactive_stream_return(outputList[-1], True, index=len(outputList) - 2) - for index in range(0, len(outputList) - 1): - assert_chat_interactive_stream_return(outputList[index], index=index) - assert output.get('finish_reason') == 'length' - assert outputList[-1].get('tokens') == 5 or outputList[-1].get('tokens') == 6 - - def test_input_validation(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Hi', top_p=0): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The top_p `0.0` must be in (0, 1].' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', top_p=1.01): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The top_p `1.01` must be in (0, 1].' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', top_p='test'): - continue - assert output.get('code') is None - assert 'Input should be a valid number' in str(output) - - for output in api_client.chat_interactive_v1(prompt='Hi', temperature=-0.01): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The temperature `-0.01` must be in [0, 2]' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', temperature=2.01): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The temperature `2.01` must be in [0, 2]' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', temperature='test'): - continue - assert output.get('code') is None - assert 'Input should be a valid number' in str(output) - - for output in api_client.chat_interactive_v1(prompt='Hi', top_k=-1): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The top_k `-1` cannot be a negative integer.' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', top_k='test'): - continue - assert output.get('code') is None - assert 'Input should be a valid integer' in str(output) - - def test_input_validation_streaming(self): - api_client = APIClient(BASE_URL) - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, top_p=0): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The top_p `0.0` must be in (0, 1].' 
- assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, top_p=1.01): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The top_p `1.01` must be in (0, 1].' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, top_p='test'): - continue - assert output.get('code') is None - assert 'Input should be a valid number' in str(output) - - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, temperature=-0.01): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The temperature `-0.01` must be in [0, 2]' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, temperature=2.01): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The temperature `2.01` must be in [0, 2]' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, temperature='test'): - continue - assert output.get('code') is None - assert 'Input should be a valid number' in str(output) - - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, top_k=-1): - continue - assert output.get('code') == 400 - assert output.get('message') == 'The top_k `-1` cannot be a negative integer.' - assert output.get('object') == 'error' - - for output in api_client.chat_interactive_v1(prompt='Hi', stream=True, top_k='test'): - continue - assert output.get('code') is None - assert 'Input should be a valid integer' in str(output) diff --git a/autotest/utils/run_restful_chat.py b/autotest/utils/run_restful_chat.py index d9c601d2d..ce6f822ed 100644 --- a/autotest/utils/run_restful_chat.py +++ b/autotest/utils/run_restful_chat.py @@ -1,7 +1,5 @@ import json import os -import random -import string import subprocess from time import sleep, time @@ -146,13 +144,6 @@ def run_all_step(config, cases_info, worker_id: str = '', port: int = DEFAULT_PO with assume: assert restful_result, msg - with allure.step(case + ' step3 - restful_test - interactive chat'): - active_result, interactive_log, msg = interactive_test(config, case, case_info, model, http_url, worker_id) - allure.attach.file(interactive_log, attachment_type=allure.attachment_type.TEXT) - - with assume: - assert active_result, msg - def open_chat_test(config, case, case_info, model, url, worker_id: str = ''): log_path = config.get('log_path') @@ -191,47 +182,6 @@ def open_chat_test(config, case, case_info, model, url, worker_id: str = ''): return result, restful_log, msg -def interactive_test(config, case, case_info, model, url, worker_id: str = ''): - log_path = config.get('log_path') - - interactive_log = os.path.join(log_path, 'interactive_' + model + worker_id + '_' + case + '.log') - - file = open(interactive_log, 'w') - - result = True - - api_client = APIClient(url) - file.writelines('available_models:' + ','.join(api_client.available_models) + '\n') - - # Randomly generate 6 characters and concatenate them into a string. 
- characters = string.digits - random_chars = ''.join(random.choice(characters) for i in range(6)) - - messages = [] - msg = '' - for prompt_detail in case_info: - prompt = list(prompt_detail.keys())[0] - new_prompt = {'role': 'user', 'content': prompt} - messages.append(new_prompt) - file.writelines('prompt:' + prompt + '\n') - - for output in api_client.chat_interactive_v1(prompt=prompt, - interactive_mode=True, - session_id=random_chars, - top_k=1, - request_output_len=256): - output_content = output.get('text') - file.writelines('output:' + output_content + '\n') - - case_result, reason = assert_result(output_content, prompt_detail.values(), model) - file.writelines('result:' + str(case_result) + ',reason:' + reason + '\n') - if not case_result: - msg += reason - result = result & case_result - file.close() - return result, interactive_log, msg - - def health_check(url): try: api_client = APIClient(url) diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index 61f99065b..14b0bd455 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -324,7 +324,7 @@ def parse_args(): cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) session_len_act = ArgumentHelper.session_len(pt_group, default=2048) - prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) + prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group) rope_scaling_factor_act = ArgumentHelper.rope_scaling_factor(pt_group) dtype_act = ArgumentHelper.dtype(pt_group) @@ -390,7 +390,7 @@ def main(): session_len=session_len, rope_scaling_factor=args.rope_scaling_factor, tp=args.tp, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, dtype=args.dtype, ) elif args.backend == 'pytorch': @@ -400,7 +400,7 @@ def main(): session_len=session_len, tp=args.tp, eager_mode=args.eager_mode, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, dtype=args.dtype, ) gen_config = GenerationConfig(top_k=args.top_k, diff --git a/benchmark/profile_pipeline_api.py b/benchmark/profile_pipeline_api.py index 05314be2e..4f2f53fb6 100644 --- a/benchmark/profile_pipeline_api.py +++ b/benchmark/profile_pipeline_api.py @@ -155,7 +155,7 @@ def parse_args(): session_len_act = ArgumentHelper.session_len(pt_group, default=4096) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) - prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) + prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group) # turbomind engine args tb_group = parser.add_argument_group('TurboMind engine argument') @@ -189,7 +189,7 @@ def main(): quant_policy=args.quant_policy, num_tokens_per_iter=args.num_tokens_per_iter, max_prefill_iters=args.max_prefill_iters, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, communicator=args.communicator, ) elif args.backend == 'pytorch': @@ -201,7 +201,7 @@ def main(): tp=args.tp, thread_safe=False, eager_mode=args.eager_mode, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, ) engine = Engine(args.model_path, engine_config, csv=args.csv) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 13855499d..0908129e7 100644 --- 
a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -78,8 +78,8 @@ def __init__(self, model_path: str, engine_config: Union[PytorchEngineConfig, Tu self.tm_model = tm_model self.pbar = None - async def _inference(self, req_queue: Queue, session_id: int, temperature: float, top_p: float, top_k: int, - stream_output: bool, skip_tokenize: bool, skip_detokenize: bool): + async def _inference(self, req_queue: Queue, temperature: float, top_p: float, top_k: int, stream_output: bool, + skip_tokenize: bool, skip_detokenize: bool): model_inst = self.tm_model.create_instance() sess: Session = None for prompt, _, output_seqlen, cancel_after, sess in iter(req_queue.get_nowait, None): @@ -96,7 +96,7 @@ async def _inference(self, req_queue: Queue, session_id: int, temperature: float prev_len = 0 token_ids = input_ids.copy() - generator = model_inst.async_stream_infer(session_id, + generator = model_inst.async_stream_infer(sess.id, input_ids=input_ids, gen_config=GenerationConfig(max_new_tokens=output_seqlen, temperature=temperature, @@ -123,7 +123,7 @@ async def _inference(self, req_queue: Queue, session_id: int, temperature: float # for pytorch engine to restart a session if isinstance(model_inst, EngineInstance): - await model_inst.async_end(session_id) + await model_inst.async_end(sess.id) self.pbar.update(1) @@ -148,8 +148,7 @@ def process_request(self, requests, profiler: Profiler, concurrency, temperature # start threads tasks = [] for i in range(concurrency): - task = self._inference(req_queue, i, temperature, top_p, top_k, stream_output, skip_tokenize, - skip_detokenize) + task = self._inference(req_queue, temperature, top_p, top_k, stream_output, skip_tokenize, skip_detokenize) tasks.append(task) async def _gather_tasks(tasks): @@ -210,7 +209,7 @@ def parse_args(): session_len_act = ArgumentHelper.session_len(pt_group, default=4096) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) - prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) + prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group) quant_policy_act = ArgumentHelper.quant_policy(pt_group, default=0) dtype_act = ArgumentHelper.dtype(pt_group) @@ -249,7 +248,7 @@ def main(): quant_policy=args.quant_policy, num_tokens_per_iter=args.num_tokens_per_iter, max_prefill_iters=args.max_prefill_iters, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, dtype=args.dtype, communicator=args.communicator, ) @@ -261,7 +260,7 @@ def main(): max_batch_size=args.concurrency, tp=args.tp, eager_mode=args.eager_mode, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, quant_policy=args.quant_policy, dtype=args.dtype, distributed_executor_backend=args.distributed_executor_backend, diff --git a/docs/en/llm/api_server.md b/docs/en/llm/api_server.md index 274ec2ff2..42cdc1f27 100644 --- a/docs/en/llm/api_server.md +++ b/docs/en/llm/api_server.md @@ -151,28 +151,6 @@ for item in api_client.completions_v1(model=model_name, prompt='hi'): print(item) ``` -As for `/v1/chat/interactive`,we disable the feature by default. Please open it by setting `interactive_mode = True`. If you don't, it falls back to openai compatible interfaces. - -Keep in mind that `session_id` indicates an identical sequence and all requests belonging to the same sequence must share the same `session_id`. 
-For instance, in a sequence with 10 rounds of chatting requests, the `session_id` in each request should be the same. - -```python -from lmdeploy.serve.openai.api_client import APIClient -api_client = APIClient(f'http://{server_ip}:{server_port}') -messages = [ - "hi, what's your name?", - "who developed you?", - "Tell me more about your developers", - "Summarize the information we've talked so far" -] -for message in messages: - for item in api_client.chat_interactive_v1(prompt=message, - session_id=1, - interactive_mode=True, - stream=False): - print(item) -``` - ### Tools May refer to [api_server_tools](./api_server_tools.md). diff --git a/docs/zh_cn/llm/api_server.md b/docs/zh_cn/llm/api_server.md index 8bb91c619..4d4c99958 100644 --- a/docs/zh_cn/llm/api_server.md +++ b/docs/zh_cn/llm/api_server.md @@ -169,28 +169,6 @@ for item in api_client.completions_v1(model=model_name, prompt='hi'): print(item) ``` -关于 `/v1/chat/interactive` 接口,我们默认是关闭的。在使用时,请设置`interactive_mode = True`打开它。否则,它会退化为 openai 接口。 - -在交互式推理中,每个对话序列的 id 必须唯一,所有属于该独立的对话请求,必须使用相同的 id。这里的 id 对应与接口中的 `session_id`。 -比如,一个对话序列中,有 10 轮对话请求,那么每轮对话请求中的 `session_id` 都要相同。 - -```python -from lmdeploy.serve.openai.api_client import APIClient -api_client = APIClient(f'http://{server_ip}:{server_port}') -messages = [ - "hi, what's your name?", - "who developed you?", - "Tell me more about your developers", - "Summarize the information we've talked so far" -] -for message in messages: - for item in api_client.chat_interactive_v1(prompt=message, - session_id=1, - interactive_mode=True, - stream=False): - print(item) -``` - ### 工具调用 参考 [api_server_tools](./api_server_tools.md)。 diff --git a/lmdeploy/api.py b/lmdeploy/api.py index 3377e7800..8c064c197 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -69,10 +69,6 @@ def pipeline(model_path: str, model_path = get_model(model_path, download_dir, revision) task, pipeline_class = get_task(model_path) - if task == 'vlm': - if backend_config and backend_config.enable_prefix_caching: - backend_config.enable_prefix_caching = False - logger.warning('VLM does not support prefix caching.') if type(backend_config) is not PytorchEngineConfig: # set auto backend mode @@ -80,6 +76,11 @@ def pipeline(model_path: str, backend = 'pytorch' if type(backend_config) is PytorchEngineConfig else 'turbomind' logger.info(f'Using {backend} engine') + if task == 'vlm': + if backend_config and backend_config.enable_prefix_caching: + backend_config.enable_prefix_caching = False + logger.warning('VLM does not support prefix caching.') + return pipeline_class(model_path, backend=backend, backend_config=backend_config, diff --git a/lmdeploy/cli/chat.py b/lmdeploy/cli/chat.py new file mode 100644 index 000000000..ac25bab44 --- /dev/null +++ b/lmdeploy/cli/chat.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
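+# CLI chat flow, for reference:
+#   * build_pipe() maps the CLI kwargs onto a TurbomindEngineConfig or
+#     PytorchEngineConfig (prefix caching stays on unless the
+#     `disable_prefix_cache` kwarg is passed) and creates the pipeline.
+#   * main() resolves the backend, opens `pipe.session(gen_config)` and
+#     streams responses; typing `end` closes the current session, `exit`
+#     quits, and Ctrl-C while tokens are being generated cancels the
+#     in-flight request via `sess.stop()`.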
+import fire + +from lmdeploy import ChatTemplateConfig, GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline +from lmdeploy.archs import autoget_backend + + +def input_prompt(): + """Input a prompt in the consolo interface.""" + print('\ndouble enter to end input >>> ', end='') + sentinel = '' # ends when this string is seen + return '\n'.join(iter(input, sentinel)) + + +def build_pipe(model_path, backend, **kwargs): + # set enable_prefix_cache + disable_prefix_cache = kwargs.pop('disable_prefix_cache', False) + kwargs.update(enable_prefix_caching=not disable_prefix_cache) + # set engine config + engine_config = None + if backend == 'turbomind': + engine_config = TurbomindEngineConfig() + for key, value in kwargs.items(): + if hasattr(TurbomindEngineConfig, key): + setattr(engine_config, key, value) + else: + engine_config = PytorchEngineConfig() + for key, value in kwargs.items(): + if hasattr(PytorchEngineConfig, key): + setattr(engine_config, key, value) + if kwargs.get('adapters', None): + from .utils import get_lora_adapters + adapters = get_lora_adapters(kwargs['adapters']) + engine_config.adapters = adapters + # set chat template config + chat_template = kwargs.get('chat_template', None) + chat_template_config = None + if chat_template: + chat_template_config = ChatTemplateConfig(model_name=chat_template) + + pipe = pipeline(model_path, + backend_config=engine_config, + chat_template_config=chat_template_config, + log_level='ERROR', + **kwargs) + return pipe + + +def build_gen_config(**kwargs): + gen_config = GenerationConfig(max_new_tokens=1024, top_k=40, top_p=0.8, temperature=0.8, repetition_penalty=1.0) + for key, value in kwargs.items(): + if hasattr(GenerationConfig, key): + setattr(gen_config, key, value) + return gen_config + + +def main(model_path, backend, **kwargs): + if backend != 'pytorch': + # set auto backend mode + backend = autoget_backend(model_path) + + pipe = build_pipe(model_path, backend, **kwargs) + gen_config = build_gen_config(**kwargs) + + quit = False + while True: + with pipe.session(gen_config) as sess: + while True: + try: + prompt = input_prompt() + except KeyboardInterrupt: + quit = True + break + if prompt == 'end': + sess.close() + break + if prompt == 'exit': + quit = True + break + resps = sess(prompt) + try: + for resp in resps: + print(resp.text, end='', flush=True) + sess.messages.append(dict(role='assistant', content=resp.text)) + except KeyboardInterrupt: + sess.stop() + finally: + print('\ncancelling the conversation') + if quit: + print('exiting...') + break + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index 6d594e1d7..3ae5e57e9 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -4,7 +4,7 @@ import os from ..version import __version__ -from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters +from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args class CLI(object): @@ -104,7 +104,7 @@ def add_parser_chat(): tp_act = ArgumentHelper.tp(pt_group) session_len_act = ArgumentHelper.session_len(pt_group) cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group) - prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) + prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group) quant_policy = ArgumentHelper.quant_policy(pt_group) # turbomind args @@ -218,39 +218,9 @@ def get_gpu_topo(): @staticmethod def chat(args): """Chat with pytorch or 
turbomind engine.""" - from lmdeploy.archs import autoget_backend - - chat_template_config = get_chat_template(args.chat_template) - - backend = args.backend - if backend != 'pytorch': - # set auto backend mode - backend = autoget_backend(args.model_path) - - if backend == 'pytorch': - from lmdeploy.messages import PytorchEngineConfig - from lmdeploy.pytorch.chat import run_chat - - adapters = get_lora_adapters(args.adapters) - engine_config = PytorchEngineConfig(dtype=args.dtype, - tp=args.tp, - session_len=args.session_len, - cache_max_entry_count=args.cache_max_entry_count, - adapters=adapters, - enable_prefix_caching=args.enable_prefix_caching, - device_type=args.device, - eager_mode=args.eager_mode, - quant_policy=args.quant_policy) - run_chat(args.model_path, engine_config, chat_template_config=chat_template_config) - else: - from lmdeploy.turbomind.chat import main as run_chat - kwargs = convert_args(args) - kwargs.pop('chat_template') - kwargs.pop('backend') - kwargs.pop('device') - kwargs.pop('eager_mode') - kwargs['chat_template_config'] = chat_template_config - run_chat(**kwargs) + from .chat import main + kwargs = convert_args(args) + main(**kwargs) @staticmethod def add_parsers(): diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index b4fde50eb..c0744e9e1 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -60,7 +60,7 @@ def add_parser_gradio(): max_batch_size_act = ArgumentHelper.max_batch_size(pt_group) cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) - prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) + prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group) max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') @@ -160,7 +160,7 @@ def add_parser_api_server(): max_batch_size_act = ArgumentHelper.max_batch_size(pt_group) cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group) cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group) - prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group) + prefix_caching_act = ArgumentHelper.disable_prefix_caching(pt_group) max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group) quant_policy = ArgumentHelper.quant_policy(pt_group) ArgumentHelper.dp(pt_group) @@ -268,7 +268,7 @@ def gradio(args): cache_max_entry_count=args.cache_max_entry_count, block_size=args.cache_block_seq_len, session_len=args.session_len, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, device_type=args.device, quant_policy=args.quant_policy, eager_mode=args.eager_mode, @@ -283,7 +283,7 @@ def gradio(args): rope_scaling_factor=args.rope_scaling_factor, cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, max_prefill_token_num=args.max_prefill_token_num, num_tokens_per_iter=args.num_tokens_per_iter, max_prefill_iters=args.max_prefill_iters, @@ -323,7 +323,7 @@ def api_server(args): block_size=args.cache_block_seq_len, session_len=args.session_len, adapters=adapters, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, device_type=args.device, quant_policy=args.quant_policy, eager_mode=args.eager_mode, @@ 
-342,7 +342,7 @@ def api_server(args): rope_scaling_factor=args.rope_scaling_factor, cache_max_entry_count=args.cache_max_entry_count, cache_block_seq_len=args.cache_block_seq_len, - enable_prefix_caching=args.enable_prefix_caching, + enable_prefix_caching=not args.disable_prefix_caching, max_prefill_token_num=args.max_prefill_token_num, communicator=args.communicator) chat_template_config = get_chat_template(args.chat_template) diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 6306a51d1..2b77f5043 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -462,13 +462,13 @@ def cache_block_seq_len(parser): 'be ignored') @staticmethod - def enable_prefix_caching(parser): + def disable_prefix_caching(parser): """Add argument enable_prefix_caching to parser.""" - return parser.add_argument('--enable-prefix-caching', + return parser.add_argument('--disable-prefix-caching', action='store_true', default=False, - help='Enable cache and match prefix') + help='Disable prefix caching') @staticmethod def num_tokens_per_iter(parser): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index b3fed7036..0cd0a05b7 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -200,7 +200,7 @@ class TurbomindEngineConfig: cache_block_seq_len (int): the length of the token sequence in a k/v block, default to 64 enable_prefix_caching (bool): enable cache prompts for block reuse, - default to False + default to True quant_policy (int): default to 0. When k/v is quantized into 4 or 8 bit, set it to 4 or 8, respectively rope_scaling_factor (float): scaling factor used for dynamic ntk, @@ -236,7 +236,7 @@ class TurbomindEngineConfig: cache_max_entry_count: float = 0.8 cache_chunk_size: int = -1 cache_block_seq_len: int = 64 - enable_prefix_caching: bool = False + enable_prefix_caching: bool = True quant_policy: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: bool = False @@ -326,7 +326,7 @@ class PytorchEngineConfig: adapters: Dict[str, str] = None max_prefill_token_num: int = 4096 thread_safe: bool = False - enable_prefix_caching: bool = False + enable_prefix_caching: bool = True device_type: str = 'cuda' eager_mode: bool = False custom_module_map: Dict[str, str] = None diff --git a/lmdeploy/profiler.py b/lmdeploy/profiler.py index 64cfb07a5..12dcd2360 100644 --- a/lmdeploy/profiler.py +++ b/lmdeploy/profiler.py @@ -1,29 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import csv import time -from typing import List +from itertools import count +from typing import List, Optional import numpy as np class Session: - UNKNOWN = 0 - SUCCESS = 1 - FAIL = 2 + UNKNOWN: int = 0 + SUCCESS: int = 1 + FAIL: int = 2 + ID = count(0) - def __init__(self, input_len, req_output_len): + def __init__(self, input_len: int, req_output_len: int, session_id: Optional[int] = None): self.ts = [] self.ns = [] self.input_len = input_len self.req_output_len = req_output_len self.status = Session.UNKNOWN + self.id = session_id if session_id else next(Session.ID) - def tick(self, n_token): + def tick(self, n_token: int): self.ts.append(time.perf_counter()) self.ns.append(n_token) - def finish(self, status): + def finish(self, status: int): self.status = status @@ -33,6 +36,7 @@ def __init__(self, stream_output: bool, percentages: List[int]): self.sessions: List[Session] = [] self.stream_output = stream_output self.percentages = percentages + self.session_id = count(0) def new_session(self, *args, **kwargs): sess = Session(*args, **kwargs) diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py deleted file mode 100644 index 903598378..000000000 --- a/lmdeploy/pytorch/chat.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import asyncio -import os -import random -from typing import Optional - -from lmdeploy.messages import GenerationConfig, PytorchEngineConfig -from lmdeploy.model import ChatTemplateConfig -from lmdeploy.serve.async_engine import get_names_from_model - -os.environ['TM_LOG_LEVEL'] = 'ERROR' - - -def input_prompt(chat_template_name): - """Input a prompt in the consolo interface.""" - if chat_template_name == 'codellama': - print('\nenter !! to end the input >>>\n', end='') - sentinel = '!!' - else: - print('\ndouble enter to end input >>> ', end='') - sentinel = '' # ends when this string is seen - return '\n'.join(iter(input, sentinel)) - - -def run_chat(model_path: str, - engine_config: PytorchEngineConfig, - gen_config: GenerationConfig = None, - session_id: int = 1, - trust_remote_code: bool = True, - chat_template_config: Optional[ChatTemplateConfig] = None): - """An example to perform model inference through the command line - interface. - - Args: - model_path (str): the huggingface model path. - engine_config (PytorchEngineConfig): Config of engine. - gen_config (GenerationConfig): Config of generation. - session_id (int): the identical id of a session. - trust_remote_code (bool): trust remote code. 
- """ - from lmdeploy import pipeline - - if gen_config is None: - gen_config = GenerationConfig(do_sample=True) - - adapter_name = None - if engine_config.adapters is not None: - adapter_name = next(iter(engine_config.adapters.keys())) - - chat_count = 0 - - def __reset_chat_state(): - """reset chat state.""" - nonlocal chat_count - seed = random.getrandbits(64) - gen_config.random_seed = seed - - async def __generate(prompt: str): - """chat generate.""" - nonlocal chat_count - print() - async for out in pipe.generate( - prompt, - session_id, - gen_config=gen_config, - sequence_start=chat_count == 0, - sequence_end=False, - adapter_name=adapter_name, - ): - print(f'{out.response}', end='', flush=True) - print() - chat_count += 1 - - async def __chat_step(prompt: str): - """chat step.""" - if prompt == 'exit': - exit(0) - elif prompt == 'end': - await pipe.stop_session(session_id) - __reset_chat_state() - else: - await __generate(prompt) - - async def __chat_loop(model_path: str): - """chat loop.""" - __reset_chat_state() - _, chat_template_name = get_names_from_model(model_path) - while True: - prompt = input_prompt(chat_template_name) - await __chat_step(prompt) - - with pipeline( - model_path, - backend_config=engine_config, - chat_template_config=chat_template_config, - ) as pipe: - try: - asyncio.run(__chat_loop(model_path)) - except KeyboardInterrupt: - exit(0) - - -def main(model_path: str, - session_id: int = 1, - top_k: float = 40, - top_p: float = 0.8, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - tp: int = 1, - adapter: str = None, - trust_remote_code: bool = True, - chat_template: str = None): - """An example to perform model inference through the command line - interface. - - Args: - model_path (str): the huggingface model path - session_id (int): the identical id of a session - top_k (int): sampling top k. - top_p (int): sampling top p. - temperature (float): sampling temperature. - repetition_penalty (float): parameter to penalize repetition - tp (int): GPU number used in tensor parallelism - adapter (str): path to lora adapter. - trust_remote_code (bool): Trust remote code. - chat_template (str): A JSON file or string that specifies the - chat template configuration. 
- """ - adapters = None - if adapter is not None: - adapters = dict(default=adapter) - engine_config = PytorchEngineConfig(tp=tp, adapters=adapters) - gen_config = GenerationConfig(max_new_tokens=512, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_penalty=repetition_penalty, - ignore_eos=False) - chat_template_config = None - if chat_template is not None and os.path.exists(chat_template): - chat_template_config = ChatTemplateConfig.from_json(chat_template) - return run_chat(model_path, - engine_config, - gen_config, - session_id=session_id, - trust_remote_code=trust_remote_code, - chat_template_config=chat_template_config) - - -if __name__ == '__main__': - import fire - - fire.Fire(main) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index dea2a1877..9817d12ca 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -78,7 +78,7 @@ class CacheConfig: window_size: int = -1 cache_max_entry_count: float = 0.8 max_prefill_token_num: int = 4096 - enable_prefix_caching: bool = False + enable_prefix_caching: bool = True quant_policy: Literal[0, 4, 8] = 0 device_type: str = 'cuda' diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 19dfc0617..ad2900ed3 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -12,7 +12,7 @@ from itertools import count from queue import Queue from threading import Thread -from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union import torch import tqdm @@ -97,30 +97,19 @@ class Session: """Session for AsyncEngine.chat. Args: - _id (int): session_id for internal use. - _step (int): the offset of the k/v cache for internal use. - _prompt (Any): input prompt for internal use. - _response (Reaponse): model output for prompt. - _engine (Any): engine for internal use. - history (List[Any, str]): chat history. 
+ _id (int): session_id for internal use + _engine (Any): engine for internal use + _response (Reaponse): model output for prompt + _gen_config (GenerationConfig): the generation config + messages (List[Dict]): chat history in openai format """ def __init__(self, session_id: int, engine: Any, gen_config: GenerationConfig = None): self._id: int = session_id self._engine = engine - self._step: int = 0 - self._prompt: Any = None self._response: Response = None self._gen_config = gen_config - self.history: List[Tuple[Any, str]] = [] - - def _merge_response(self, resp: Response, step: Union[Response, GenOut]): - """merge response.""" - resp.text += step.text if isinstance(step, Response) else step.response - resp.input_token_len = step.input_token_len - resp.generate_token_len = step.generate_token_len - resp.finish_reason = step.finish_reason - return resp + self.messages: List[Dict] = [] @property def response(self) -> Response: @@ -132,14 +121,13 @@ def close(self): if self._engine: self._engine._run(coro=self._engine.end_session(self._id)).result() self._engine = None + self.messages = [] - def __repr__(self) -> str: - res = '' - for user, assistant in self.history: - if isinstance(user, list): - user = str(user) - res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n' - return res + def stop(self): + """stop the session while tokens are being generated.""" + if self._engine: + self._engine._run(coro=self._engine.stop_session(self._id)).result() + self.messages = [] def __enter__(self): return self @@ -152,7 +140,8 @@ def __call__(self, gen_config: Optional[GenerationConfig] = None, stream_response: bool = True, do_preprocess: bool = True) -> Union[Response, Iterator[Response]]: - self._engine.chat(prompt=prompt, + self.messages.append(dict(role='user', content=prompt)) + self._engine.chat(prompt=self.messages, gen_config=gen_config or self._gen_config, stream_response=stream_response, do_preprocess=do_preprocess, @@ -839,7 +828,7 @@ def session(self, gen_config: GenerationConfig = None): return Session(self._run(fn=lambda: next(self._session_id)).result(), engine=self, gen_config=gen_config) def chat(self, - prompt: str, + prompt: Union[List[Dict], str], session=None, gen_config: Optional[GenerationConfig] = None, stream_response=False, @@ -847,7 +836,7 @@ def chat(self, """Chat. Args: - prompt (str): prompt + prompt (Union[List[Dict], str]): it can be an openai-like message or a string session (Session): the chat session gen_config (GenerationConfig | None): a instance of GenerationConfig. Default to None. 
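
Together with the `Session` rework above, multi-round chat through the pipeline now keeps an OpenAI-style history in `Session.messages` and resends it each round with `sequence_start=True, sequence_end=True`; prefix caching (now on by default) takes over the role interactive mode used to play, and the context step is reset after every round. A minimal sketch of the resulting usage, mirroring the new `lmdeploy/cli/chat.py`; the model path is only a placeholder:

```python
from lmdeploy import GenerationConfig, pipeline

# placeholder model path; any chat model supported by lmdeploy works here
pipe = pipeline('internlm/internlm2_5-7b-chat')

with pipe.session(GenerationConfig(max_new_tokens=512)) as sess:
    for prompt in ["Hi, what's your name?", 'Who developed you?']:
        # Session.__call__ appends the user turn to sess.messages and
        # forwards the whole history to AsyncEngine.chat()
        for resp in sess(prompt):
            print(resp.text, end='', flush=True)
        print()
```
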
@@ -858,16 +847,19 @@ def chat(self, if session is None: session = self.session() - # sync & init - session._prompt = prompt - session._response = None + if isinstance(prompt, str): + session.messages.append(dict(role='user', content=prompt)) + elif isinstance(prompt, List) and all(isinstance(_, Dict) for _ in prompt): + session.messages.extend(prompt) + else: + raise ValueError(f'unsupported prompt: {prompt}') - sequence_start = session._step == 0 + session._response = None - generator = self.infer(prompt, + generator = self.infer(session.messages, gen_config, - sequence_start=sequence_start, - sequence_end=False, + sequence_start=True, + sequence_end=True, session_id=session._id, stream_response=stream_response, multiplex=True) @@ -883,8 +875,10 @@ def _gen(): raise else: session._response = resp - session._step += resp.generate_token_len + resp.input_token_len - session.history.append((session._prompt, resp.text)) + session.messages.append(dict(role='user', content=resp.text)) + # Since prefix caching is used to substitute interactive mode, the context step should be + # reset after each round + self.id2step[session._id] = 0 if stream_response: session.generator = _gen() diff --git a/lmdeploy/serve/openai/api_client.py b/lmdeploy/serve/openai/api_client.py index 79fd04570..21ca040d2 100644 --- a/lmdeploy/serve/openai/api_client.py +++ b/lmdeploy/serve/openai/api_client.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import json -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import requests @@ -168,70 +168,6 @@ def chat_completions_v1(self, output = json_loads(decoded) yield output - def chat_interactive_v1(self, - prompt: Union[str, List[Dict[str, str]]], - image_url: Optional[Union[str, List[str]]] = None, - session_id: int = -1, - interactive_mode: bool = False, - stream: bool = False, - stop: Optional[Union[str, List[str]]] = None, - request_output_len: Optional[int] = None, - top_p: float = 0.8, - top_k: int = 40, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - ignore_eos: bool = False, - skip_special_tokens: Optional[bool] = True, - adapter_name: Optional[str] = None, - **kwargs): - """Interactive completions. - - - On interactive mode, the chat history is kept on the server. Please - set `interactive_mode = True`. - - On normal mode, no chat history is kept on the server. Set - `interactive_mode = False`. - - Args: - prompt: the prompt to use for the generation. - image_url (str | List[str] | None): the image url or base64 encoded - string for VL models. - session_id: determine which instance will be called. - If not specified with a value other than -1, using random value - directly. - interactive_mode (bool): turn on interactive mode or not. On - interactive mode, session history is kept on the server (and - vice versa). - stream: whether to stream the results or not. - stop (str | List[str] | None): To stop generating further tokens. - Only accept stop words that's encoded to one token idex. - request_output_len (int): output token nums. If not specified, - will use maximum possible number for a session. - top_p (float): If set to float < 1, only the smallest set of most - probable tokens with probabilities that add up to top_p or - higher are kept for generation. 
- top_k (int): The number of the highest probability vocabulary - tokens to keep for top-k-filtering - temperature (float): to modulate the next token probability - repetition_penalty (float): The parameter for repetition penalty. - 1.0 means no penalty - ignore_eos (bool): indicator for ignoring eos - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - adapter_name (str): For slora inference. Choose which lora to do - the inference. - - Yields: - json objects consist of text, tokens, input_tokens, - history_tokens, finish_reason - """ - pload = {k: v for k, v in locals().copy().items() if k[:2] != '__' and k not in ['self']} - response = requests.post(self.chat_intractive_v1_url, headers=self.headers, json=pload, stream=stream) - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\n'): - if chunk: - decoded = chunk.decode('utf-8') - output = json_loads(decoded) - yield output - def completions_v1( self, model: str, @@ -304,7 +240,7 @@ def completions_v1( yield output def chat(self, - prompt: str, + messages: List[str], session_id: int, image_url: Optional[Union[str, List[str]]] = None, request_output_len: int = 512, @@ -317,7 +253,7 @@ def chat(self, """Chat with a unique session_id. Args: - prompt: the prompt to use for the generation. + messages(List): the chat context, including history session_id: determine which instance will be called. If not specified with a value other than -1, using random value directly. @@ -340,35 +276,23 @@ def chat(self, text, tokens, finish_reason """ assert session_id != -1, 'please set a value other than -1' - for outputs in self.chat_interactive_v1(prompt, + for outputs in self.chat_completions_v1(model=self.available_models[0], + messages=messages, + temperature=temperature, + top_p=top_p, session_id=session_id, image_url=image_url, - request_output_len=request_output_len, - interactive_mode=True, - stream=stream, + max_tokens=request_output_len, + stream=True, top_k=top_k, - top_p=top_p, - temperature=temperature, repetition_penalty=repetition_penalty, ignore_eos=ignore_eos): - if outputs['finish_reason'] == 'length' and outputs['tokens'] == 0: + finish_reason = outputs['choices'][0]['finish_reason'] + content = outputs['choices'][0]['delta']['content'] + if finish_reason == 'length' and content == '': print('WARNING: exceed session max length.' ' Please end the session.') - yield outputs['text'], outputs['tokens'], outputs['finish_reason'] - - def end_session(self, session_id: int): - """End the session with a unique session_id. - - Args: - session_id: determine which instance will be called. - If not specified with a value other than -1, using random value - directly. 
- """ - for out in self.chat_interactive_v1(prompt='', - session_id=session_id, - request_output_len=0, - interactive_mode=False): - pass + yield content, finish_reason def input_prompt(): @@ -378,62 +302,31 @@ def input_prompt(): return '\n'.join(iter(input, sentinel)) -def get_streaming_response(prompt: str, - api_url: str, - session_id: int, - request_output_len: int = 512, - stream: bool = True, - interactive_mode: bool = False, - ignore_eos: bool = False, - cancel: bool = False, - top_p: float = 0.8, - temperature: float = 0.7, - api_key: Optional[str] = None) -> Iterable[List[str]]: - headers = {'User-Agent': 'Test Client'} - if api_key is not None: - headers['Authorization'] = f'Bearer {api_key}' - pload = { - 'prompt': prompt, - 'stream': stream, - 'session_id': session_id, - 'request_output_len': request_output_len, - 'interactive_mode': interactive_mode, - 'ignore_eos': ignore_eos, - 'cancel': cancel, - 'top_p': top_p, - 'temperature': temperature - } - response = requests.post(api_url, headers=headers, json=pload, stream=stream) - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\n'): - if chunk: - data = json_loads(chunk.decode('utf-8')) - output = data.pop('text', '') - tokens = data.pop('tokens', 0) - finish_reason = data.pop('finish_reason', None) - yield output, tokens, finish_reason - - def main(api_server_url: str = 'http://0.0.0.0:23333', session_id: int = 0, api_key: Optional[str] = None): """Main function to chat in terminal.""" if not api_server_url.startswith('http://'): - print(f'[WARNING] api_server_url of the api_server should ' - f'start with "http://", but got "{api_server_url}"') + print(f'[WARNING] api_server_url should start with "http://", but got "{api_server_url}"') # noqa: E231 api_server_url = 'http://' + api_server_url.strip() api_client = APIClient(api_server_url, api_key=api_key) + messages = [] while True: prompt = input_prompt() if prompt in ['exit', 'end']: - api_client.end_session(session_id) + messages = [] if prompt == 'exit': exit(0) else: - for text, tokens, finish_reason in api_client.chat(prompt, - session_id=session_id, - request_output_len=512, - stream=True): + messages.append(dict(role='user', content=prompt)) + response = [] + for text, finish_reason in api_client.chat(messages, + session_id=session_id, + request_output_len=512, + stream=True): if finish_reason == 'length': continue print(text, end='') + response.append(text) + messages.append(dict(role='assistant', content=''.join(response))) if __name__ == '__main__': diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 78c607850..8a991e5ce 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -29,8 +29,8 @@ CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse, - GenerateRequest, GenerateResponse, LogProbs, ModelCard, ModelList, - ModelPermission, TopLogprob, UsageInfo) + GenerateRequest, LogProbs, ModelCard, ModelList, ModelPermission, + TopLogprob, UsageInfo) from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager from lmdeploy.tokenizer import DetokenizeState, Tokenizer @@ -355,7 +355,7 @@ async def chat_completions_v1(raw_request: Request = None): if error_check_ret is not None: return error_check_ret if 
VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.') + return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id} is occupied.') model_name = request.model adapter_name = None @@ -665,7 +665,7 @@ async def completions_v1(raw_request: Request = None): if error_check_ret is not None: return error_check_ret if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.') + return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id} is occupied.') model_name = request.model adapter_name = None @@ -912,142 +912,8 @@ async def free_cache(raw_request: Request) -> JSONResponse: @router.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)]) async def chat_interactive_v1(request: GenerateRequest, raw_request: Request = None): - """Generate completion for the request. - - - On interactive mode, the chat history is kept on the server. Please set - `interactive_mode = True`. - - On normal mode, no chat history is kept on the server. Set - `interactive_mode = False`. - - The request should be a JSON object with the following fields: - - prompt: the prompt to use for the generation. - - image_url(str | List[str] | None): the image url or base64 encoded string - for VL models. - - session_id: determine which instance will be called. If not specified - with a value other than -1, using random value directly. - - interactive_mode (bool): turn on interactive mode or not. On interactive - mode, session history is kept on the server (and vice versa). - - stream: whether to stream the results or not. - - stop (str | List[str] | None): To stop generating further - tokens. Only accept stop words that's encoded to one token idex. - - request_output_len (int): output token nums. If not specified, will use - maximum possible number for a session. - - top_p (float): If set to float < 1, only the smallest set of most - probable tokens with probabilities that add up to top_p or higher - are kept for generation. - - top_k (int): The number of the highest probability vocabulary - tokens to keep for top-k-filtering - - temperature (float): to modulate the next token probability - - repetition_penalty (float): The parameter for repetition penalty. - 1.0 means no penalty - - ignore_eos (bool): indicator for ignoring eos - - skip_special_tokens (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - - spaces_between_special_tokens (bool): Whether or not to add spaces - around special tokens. The behavior of Fast tokenizers is to have - this to False. This is setup to True in slow tokenizers. - - adapter_name (str): For slora inference. Choose which lora to do the - inference. - - min_new_tokens (int): To generate at least numbers of tokens. - - min_p (float): Minimum token probability, which will be scaled by the - probability of the most likely token. It must be a value between - 0 and 1. 
Typical values are in the 0.01-0.2 range, comparably - selective as setting `top_p` in the 0.99-0.8 range (use the - opposite of normal `top_p` values) - """ - if request.cancel: - if request.session_id != -1: - await VariableInterface.async_engine.stop_session(request.session_id) - return {'text': '', 'tokens': 0, 'input_tokens': 0, 'history_tokens': 0, 'finish_reason': 'stop'} - else: - return create_error_response(HTTPStatus.BAD_REQUEST, 'please set a session_id to cancel a request') - error_check_ret = await check_request(request) - if error_check_ret is not None: - return error_check_ret - if request.session_id == -1: - VariableInterface.session_id += 1 - request.session_id = VariableInterface.session_id - - async_engine = VariableInterface.async_engine - sequence_start = async_engine.id2step.get(request.session_id, 0) == 0 - sequence_end = not request.interactive_mode - if isinstance(request.stop, str): - request.stop = [request.stop] - - end_session = sequence_end and request.prompt == '' and request.request_output_len == 0 - if end_session: - await async_engine.end_session(request.session_id) - return JSONResponse(dict(text='', tokens=0, input_tokens=0, history_tokens=0, finish_reason='stop')) - - random_seed = request.seed if request.seed else None - - gen_config = GenerationConfig(max_new_tokens=request.request_output_len, - do_sample=True, - top_p=request.top_p, - top_k=request.top_k, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, - ignore_eos=request.ignore_eos, - stop_words=request.stop, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=request.spaces_between_special_tokens, - min_new_tokens=request.min_new_tokens, - min_p=request.min_p, - random_seed=random_seed) - if request.image_url: - from lmdeploy.vl import load_image - if isinstance(request.image_url, List): - request.prompt = (request.prompt, [load_image(url) for url in request.image_url]) - else: - request.prompt = (request.prompt, load_image(request.image_url)) - if not hasattr(async_engine, '_convert_prompts'): - return create_error_response(HTTPStatus.BAD_REQUEST, '`image_url` argument only works for VL model') - request.prompt = async_engine._convert_prompts(request.prompt) - generation = async_engine.generate( - request.prompt, - request.session_id, - gen_config=gen_config, - stream_response=True, # always use stream to enable batching - sequence_start=sequence_start, - sequence_end=sequence_end, - adapter_name=request.adapter_name) - - # Streaming case - async def stream_results() -> AsyncGenerator[bytes, None]: - async for out in generation: - chunk = GenerateResponse(text=out.response, - tokens=out.generate_token_len, - input_tokens=out.input_token_len, - history_tokens=out.history_token_len, - finish_reason=out.finish_reason) - data = chunk.model_dump_json() - yield f'{data}\n' - - if request.stream: - return StreamingResponse(stream_results(), media_type='text/event-stream') - else: - ret = {} - text = '' - tokens, input_tokens, history_tokens = 0, 0, 0 - finish_reason = None - async for out in generation: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. 
- await async_engine.stop_session(request.session_id) - return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') - text += out.response - tokens = out.generate_token_len - input_tokens = out.input_token_len - history_tokens = out.history_token_len - finish_reason = out.finish_reason - ret = { - 'text': text, - 'tokens': tokens, - 'input_tokens': input_tokens, - 'history_tokens': history_tokens, - 'finish_reason': finish_reason - } - return JSONResponse(ret) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'v1/chat/interactive is removed, please use v1/chat/completions instead') def handle_torchrun(): diff --git a/lmdeploy/serve/utils.py b/lmdeploy/serve/utils.py index 9b6b37431..3cf9ba409 100644 --- a/lmdeploy/serve/utils.py +++ b/lmdeploy/serve/utils.py @@ -43,29 +43,76 @@ def get_reward_score(self, input_ids: List) -> List[float]: async def _async_get_logits(self, input_ids, steps: List[int] = None, + max_input_len: int = None, sequence_start: bool = True, sequence_end: bool = True) -> List[torch.Tensor]: assert input_ids and all(isinstance(_, List) for _ in input_ids) assert steps is None or (len(steps) == len(input_ids)) + steps = steps or [0] * len(input_ids) + max_input_len = max_input_len or max([len(x) for x in input_ids]) + + if self.backend == 'turbomind': + logits = await self._async_get_logits_by_turbomind(input_ids, steps, max_input_len) + else: + logits = await self._async_get_logits_by_pytorch(input_ids, steps, max_input_len, sequence_start, + sequence_end) + return logits + + async def _async_get_logits_by_turbomind(self, input_ids, steps, max_input_len): + assert len(input_ids) == len(steps) + + if any(s != 0 for s in steps): + assert self.backend_config.enable_prefix_caching, 'please enable prefix caching' + assert all(s % self.backend_config.cache_block_seq_len == 0 for s in steps) + logits = [None] * len(input_ids) + gen_config = GenerationConfig(max_new_tokens=1, output_logits='all', do_sample=False) async def _proc(i): - async with self.model_inst(session_id=i) as inst: - input_len = len(input_ids[i]) - # TODO(lvhan): Fix the ugly code later on - max_new_tokens = 1 if self.backend == 'turbomind' else 0 + session_id = next(self._session_id) + async with self.model_inst(session_id=session_id) as inst: + token_ids = input_ids[i][:steps[i] + max_input_len] + input_len = len(token_ids) + async with self.safe_run(inst, + session_id=session_id, + input_ids=token_ids, + gen_config=gen_config, + stream_output=False, + step=steps[i]) as gen: + async for outputs in gen: + pass + logits[i] = outputs.logits[:input_len - steps[i], :] + + tasks = [_proc(i) for i in range(len(input_ids))] + await asyncio.gather(*tasks) + + return logits + + async def _async_get_logits_by_pytorch(self, + input_ids: List[List[int]], + steps: List[int], + max_input_len: int, + sequence_start: bool = True, + sequence_end: bool = True): + logits = [None] * len(input_ids) + + async def _proc(i): + session_id = next(self._session_id) + async with self.model_inst(session_id=session_id) as inst: + token_ids = input_ids[i][steps[i]:steps[i] + max_input_len] + input_len = len(token_ids) # The reason to set `top_k=1` is that pt engine crashes at top_k sampling stage # when perform inference on a reward model.
- gen_config = GenerationConfig(max_new_tokens=max_new_tokens, output_logits='all', top_k=1) + gen_config = GenerationConfig(max_new_tokens=0, output_logits='all', top_k=1) async with self.safe_run(inst, - session_id=i, - input_ids=input_ids[i], + session_id=session_id, + input_ids=token_ids, gen_config=gen_config, stream_output=False, sequence_start=sequence_start, sequence_end=sequence_end, - step=steps[i] if steps else 0) as gen: + step=steps[i]) as gen: async for outputs in gen: pass logits[i] = outputs.logits[:input_len, :] @@ -73,7 +120,7 @@ async def _proc(i): session_ids = list(range(len(input_ids))) tasks = [_proc(i) for i in range(len(input_ids))] await asyncio.gather(*tasks) - if sequence_end and self.backend == 'pytorch': + if sequence_end: for session_id in session_ids: await self.end_session(session_id) return logits @@ -94,10 +141,7 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]: input_ids = [input_ids] assert all(len(_) > 1 for _ in input_ids) - # TODO: a better way to determine `max_input_len`, at most allocate - # 2G mem for logits with shape [bs, max_input_len, vocab_size] - vocab_size = self.hf_tm_cfg.vocab_size - max_input_len = 2 * 1024**3 // (vocab_size * 4) + max_input_len = self.backend_config.max_prefill_token_num sizes = [len(_) for _ in input_ids] result = [] sorted_index_values = sorted(list(enumerate(sizes)), key=lambda x: x[1], reverse=True) @@ -113,10 +157,8 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]: result.append(res) else: _input_ids = [input_ids[indices[i]] for i in range(start, end)] - res = self._get_ppl( - input_ids=_input_ids, - max_input_len=max_input_len, - ) + steps = [0] * len(_input_ids) + res = self._get_ppl(input_ids=_input_ids, steps=steps, max_input_len=max_input_len) result.extend(res) output = list(range(len(result))) for index, sorted_index in enumerate(indices): @@ -154,55 +196,36 @@ def _get_long_text_ppl(self, input_ids, max_input_len): losses = [] target_counts = [] for i in range(0, seq_len, max_input_len): - token_ids = input_ids[i:i + max_input_len] - step = [i] - # shift token_ids by 1 to the left - target_ids = input_ids[i + 1:i + 1 + max_input_len] - loss = self._get_ppl(input_ids=[token_ids], - max_input_len=len(token_ids), - target_ids=[target_ids], - steps=step, - sequence_start=(i == 0), - sequence_end=False) + loss, target_count = self._get_ppl(input_ids=[input_ids], + steps=[i], + max_input_len=max_input_len, + sequence_start=(i == 0), + sequence_end=False) losses.extend(loss) - target_counts.append(len(target_ids)) + target_counts.extend(target_count) losses = [loss * target_count for loss, target_count in zip(losses, target_counts)] loss_sum = sum(losses) target_count = sum(target_counts) return loss_sum / target_count - def _get_ppl(self, - input_ids, - max_input_len, - target_ids=None, - steps=None, - sequence_start: bool = True, - sequence_end: bool = True): - assert (isinstance(input_ids, List) and all(isinstance(_, List) for _ in input_ids)) - assert steps is None or len(steps) == len(input_ids) - assert target_ids is None or len(target_ids) == len(input_ids) - - lens = [len(_) for _ in input_ids] - total_len = sum(lens) - assert sum(lens) <= max_input_len - - logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, ' - f'total_len: {total_len}, steps: {steps}') + def _get_ppl(self, input_ids, steps, max_input_len, sequence_start: bool = True, sequence_end: bool = True): + assert isinstance(steps, List) and len(steps) == len(input_ids) + 
torch.cuda.empty_cache() - logits = self._run(coro=self._async_get_logits( - input_ids=input_ids, steps=steps, sequence_start=sequence_start, sequence_end=sequence_end)).result() + logits = self._run(coro=self._async_get_logits(input_ids=input_ids, + steps=steps, + max_input_len=max_input_len, + sequence_start=sequence_start, + sequence_end=sequence_end)).result() padding_token_id = -100 - if target_ids is None: - target_ids = [x[1:] + [padding_token_id] for x in input_ids] - else: - target_ids = [ - target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i] - for i in range(len(input_ids)) - ] - target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids] + # shift token_ids by 1 to the left + target_ids = [s[steps[i] + 1:steps[i] + 1 + max_input_len] for i, s in enumerate(input_ids)] + target_ids = [t + [padding_token_id] if len(t) < max_input_len else t for t in target_ids] + target_ids = [torch.Tensor(torch.LongTensor(t)) for t in target_ids] result = [] + target_counts = [] for _logits, _target_ids in zip(logits, target_ids): _logits = _logits.float() vocab_size = _logits.shape[-1] @@ -218,5 +241,6 @@ def _get_ppl(self, loss = flat_loss_matrix.sum() target_count = target_mask.sum() result.append(loss.item() / target_count.item()) + target_counts.append(target_count) logger.info(f'ppl result: {result}') - return result + return result, target_counts diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index b92a72cbf..27a33186f 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -210,10 +210,6 @@ def chat(self, prompts: VLPromptType, *args, **kwargs): _prompts = self._convert_prompts(prompts) sess = super().chat(_prompts, *args, **kwargs) - # recover prompts & history - sess._prompt = prompts - last_round = sess.history[-1] - sess.history[-1] = (prompts, last_round[-1]) return sess @classmethod diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py deleted file mode 100644 index dd4c1fe3a..000000000 --- a/lmdeploy/turbomind/chat.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import asyncio -import os -import random - -from lmdeploy import Tokenizer -from lmdeploy.archs import get_model_arch -from lmdeploy.messages import GenerationConfig, TurbomindEngineConfig -from lmdeploy.model import ChatTemplateConfig -from lmdeploy.serve.async_engine import get_names_from_model -from lmdeploy.tokenizer import DetokenizeState -from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg - -log_level = 'ERROR' -if os.getenv('TM_LOG_LEVEL') is None: - os.environ['TM_LOG_LEVEL'] = log_level - from lmdeploy.utils import get_logger - logger = get_logger('lmdeploy') - logger.setLevel(log_level) - - -def input_prompt(model_name): - """Input a prompt in the consolo interface.""" - if model_name == 'codellama': - print('\nenter !! to end the input >>>\n', end='') - sentinel = '!!' 
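`_get_ppl` now shifts each chunk's targets one token to the left, pads them with `-100`, and returns the per-chunk mean loss together with the number of valid targets, so `_get_long_text_ppl` can weight chunks by token count. A compact PyTorch sketch of that masking and aggregation (illustrative only, not the repo's exact code):

```python
import torch
import torch.nn.functional as F


def chunk_ppl_stats(logits: torch.Tensor, target_ids: torch.Tensor, pad_id: int = -100):
    """Mean negative log-likelihood and valid-target count for one chunk."""
    target_ids = target_ids.long()
    mask = target_ids != pad_id
    # per-token cross entropy; padded positions are zeroed out by the mask
    nll = F.cross_entropy(logits.float(), target_ids.clamp(min=0), reduction='none')
    loss = (nll * mask).sum()
    count = mask.sum()
    return (loss / count).item(), count.item()


def aggregate_ppl(chunks):
    """Weight each chunk's mean loss by its target count, as `_get_long_text_ppl` does."""
    loss_sum = sum(loss * count for loss, count in chunks)
    return loss_sum / sum(count for _, count in chunks)
```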
- else: - print('\ndouble enter to end input >>> ', end='') - sentinel = '' # ends when this string is seen - return '\n'.join(iter(input, sentinel)) - - -async def async_infer(generator, session_id, input_ids, gen_config, sequence_start, step, stream_output, tokenizer, - state): - token_ids = input_ids.copy() - prev_len = 0 - async for output in generator.async_stream_infer(session_id=session_id, - input_ids=input_ids, - gen_config=gen_config, - sequence_start=sequence_start, - sequence_end=False, - step=step, - stream_output=stream_output): - tokens = output.num_token - if tokens > prev_len: - token_ids += output.token_ids[prev_len - tokens:] - response, state = tokenizer.detokenize_incrementally(token_ids, state=state) - prev_len = tokens - print(response, end='', flush=True) - return tokens - - -def main(model_path: str, - session_id: int = 1, - top_k: float = 40, - top_p: float = 0.8, - temperature: float = 0.8, - repetition_penalty: float = 1.0, - cap: str = 'chat', - dtype: str = 'auto', - tp: int = 1, - model_format: str = None, - quant_policy: int = 0, - cache_max_entry_count: float = 0.8, - cache_block_seq_len: int = 64, - rope_scaling_factor: float = 0.0, - enable_prefix_caching: bool = False, - session_len: int = None, - stream_output: bool = True, - request_output_len: int = 1024, - chat_template_config: ChatTemplateConfig = None, - communicator: str = 'nccl', - **kwargs): - """An example to perform model inference through the command line - interface. - - Args: - model_path (str): the path of the deployed model - session_id (int): the identical id of a session - top_k (int): sampling top k. - top_p (int): sampling top p. - temperature (float): sampling temperature. - repetition_penalty (float): parameter to penalize repetition - cap (str): the capability of a model. For example, codellama has the - ability among ['completion', 'infilling', 'chat', 'python'] - dtype (str): data type for model weights and activations. It can be - one of the following values, ['auto', 'float16', 'bfloat16'] - The `auto` option will use FP16 precision for FP32 and FP16 - models, and BF16 precision for BF16 models. - tp (int): GPU number used in tensor parallelism - model_format (str): the layout of the deployed model. It can be one - of the following values [hf, llama, awq] - quant_policy (int): default to 0. When k/v is quantized into 4 or 8 - bit, set it to 4 or 8, respectively - cache_max_entry_count (float): the percentage of gpu memory occupied - by the k/v cache. - cache_block_seq_len (int): the length of the token sequence in a k/v - block, default to 64 - rope_scaling_factor (float): scaling factor used for dynamic ntk, - default to 0. 
TurboMind follows the implementation of transformer - LlamaAttention - enable_prefix_caching (bool): whether enable prefix caching - session_len (int): the length input output tokens - stream_output (bool): indicator for streaming output or not - request_output_len (int): output token nums - chat_template_config (ChatTemplateConfig): chat template config - kwargs (dict): unused args - """ - - # chat template - _, chat_template_name = get_names_from_model(model_path) - if chat_template_config is None: - chat_template_config = ChatTemplateConfig(chat_template_name) - elif chat_template_config.model_name is None: - chat_template_config.model_name = chat_template_name - if chat_template_config.capability is None: - chat_template_config.capability = cap - print('chat_template_config:\n', chat_template_config, sep='', flush=True) - model = chat_template_config.chat_template - - _, model_config = get_model_arch(model_path) - session_len = _get_and_verify_max_len(model_config, session_len) - - # engine - engine_cfg = TurbomindEngineConfig(max_batch_size=1, - model_format=model_format, - session_len=session_len, - cache_max_entry_count=cache_max_entry_count, - cache_block_seq_len=cache_block_seq_len, - enable_prefix_caching=enable_prefix_caching, - quant_policy=quant_policy, - rope_scaling_factor=rope_scaling_factor, - dtype=dtype, - tp=tp, - communicator=communicator) - print('engine_cfg:\n', engine_cfg, sep='', flush=True) - tokenizer = Tokenizer(model_path) - from lmdeploy import turbomind as tm - tm_model = tm.TurboMind.from_pretrained(model_path, tokenizer=tokenizer, engine_config=engine_cfg) - generator = tm_model.create_instance() - - # generation config - gen_config = GenerationConfig(max_new_tokens=request_output_len, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_penalty=repetition_penalty) - stop_words = _stop_words(model.stop_words, tokenizer) - gen_config.convert_stop_bad_words_to_ids(tokenizer) - if stop_words is not None: - stop_words = stop_words[0][0].tolist() - if gen_config.stop_token_ids is None: - gen_config.stop_token_ids = stop_words - hf_gen_cfg = get_hf_gen_cfg(model_path) - gen_config.update_from_hf_gen_cfg(hf_gen_cfg, tokenizer.eos_token_id) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - nth_round = 1 - step = 0 - seed = random.getrandbits(64) - while True: - prompt = input_prompt(chat_template_name) - if prompt == 'exit': - exit(0) - elif prompt == 'end': - loop.run_until_complete(generator.async_end(session_id)) - nth_round = 1 - step = 0 - seed = random.getrandbits(64) - else: - prompt = model.get_prompt(prompt, nth_round == 1) - input_ids = tokenizer.encode(prompt, nth_round == 1) - gen_config.random_seed = seed - - if model.capability == 'chat': - sequence_start = (nth_round == 1) - else: - sequence_start = True - step = 0 - - if step + len(input_ids) + request_output_len >= tm_model.session_len: - print('WARNING: exceed session max length.' 
- ' Please end the session.') - continue - - print(f'{prompt}', end='', flush=True) - state = DetokenizeState(len(input_ids)) - - coro = async_infer(generator, session_id, input_ids, gen_config, sequence_start, step, stream_output, - tokenizer, state) - tokens = loop.run_until_complete(coro) - - # update step - step += len(input_ids) + tokens - print() - - nth_round += 1 - - -if __name__ == '__main__': - import fire - - fire.Fire(main) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 3ff1dc143..4aceee643 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -223,8 +223,7 @@ def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: Tu # pack `self.config` and `self.engine_config` into a dict self.config_dict = self.config.to_dict() self.config_dict.update(dict(engine_config=asdict(self.engine_config))) - logger.info(f'turbomind model config:\n\n' - f'{json.dumps(self.config_dict, indent=2)}') + logger.info(f'turbomind model config:\n\n{json.dumps(self.config_dict, indent=2)}') def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: TurbomindEngineConfig): """Load model which is in hf format.""" @@ -526,15 +525,8 @@ def prepare_inputs(self, async def async_cancel(self, session_id: int = None): self.model_inst.cancel() - def async_end_cb(self, fut: asyncio.Future, status: int): - """executing on engine's signaling thread.""" - logger.info(f'[async_end_cb] session ended, status = {status}') - fut.get_loop().call_soon_threadsafe(fut.set_result, status) - async def async_end(self, session_id): - fut = asyncio.get_running_loop().create_future() - self.model_inst.end(partial(self.async_end_cb, fut), session_id) - await fut + pass def async_signal_cb(self, s: StreamingSemaphore): """executing on engine's signaling thread.""" @@ -575,7 +567,7 @@ async def async_stream_infer(self, input_embedding_ranges=input_embedding_ranges, gen_config=gen_config) - session = _tm.SessionParam(id=session_id, step=step, start=sequence_start, end=sequence_end) + session = _tm.SessionParam(id=session_id, step=step) inputs = _np_dict_to_tm_dict(inputs) diff --git a/src/turbomind/engine/gateway.h b/src/turbomind/engine/gateway.h index 835082204..bb3295162 100644 --- a/src/turbomind/engine/gateway.h +++ b/src/turbomind/engine/gateway.h @@ -68,13 +68,7 @@ class Gateway { { int rank = -1; - if (!r->session.start_flag) { - // route to corresponding rank - rank = seqid2rank_.find(r->session.id); - } - else { - rank = next_.fetch_add(1, std::memory_order_relaxed) % size_; - } + rank = next_.fetch_add(1, std::memory_order_relaxed) % size_; if (rank >= 0) { queues_[rank]->push({std::move(r)}); @@ -87,15 +81,10 @@ class Gateway { } } - void pop(std::vector>& infer_reqs, - std::vector>& kill_reqs, - unsigned max_infer, - bool blocking, - bool& abort, - int rank) + void + pop(std::vector>& infer_reqs, unsigned max_infer, bool blocking, bool& abort, int rank) { infer_reqs.clear(); - kill_reqs.clear(); [&] { for (int i = 0; i < size_; ++i) { @@ -110,7 +99,7 @@ class Gateway { blocking = blocking && infer_reqs.empty(); - if (queues_[rank]->pop(infer_reqs, kill_reqs, max_infer, blocking, abort)) { + if (queues_[rank]->pop(infer_reqs, max_infer, blocking, abort)) { const int group_id = rank / group_size_; // Wake all siblings for (int i = group_id * group_size_; i < (group_id + 1) * group_size_; ++i) { @@ -129,23 +118,14 @@ class Gateway { // Bind for stateful inference std::vector bind_ids; - for (const auto& r : infer_reqs) { 
- if (r->session.start_flag && !r->session.end_flag) { // started but not ended - bind_ids.push_back(r->session.id); - } - } + // for (const auto& r : infer_reqs) { + // if (r->session.start_flag && !r->session.end_flag) { // started but not ended + // bind_ids.push_back(r->session.id); + // } + // } if (!bind_ids.empty()) { seqid2rank_.bind(bind_ids, rank); } - - // Unbind for stateful kill - std::vector unbind_ids; - for (const auto& r : kill_reqs) { - unbind_ids.push_back(r->session.id); - } - if (!unbind_ids.empty()) { - seqid2rank_.unbind(unbind_ids, rank); - } } void cancel(std::shared_ptr r) @@ -161,19 +141,6 @@ class Gateway { } } - void kill(std::shared_ptr r) - { - if (auto rank = seqid2rank_.find(r->session.id); rank >= 0) { - queues_[rank]->kill(std::move(r)); - } - else { - TM_LOG_ERROR("[Gateway] Failed to find a binded queue for %lu", r->session.id); - notify({[r = std::move(r)] { // - UpdateState(*r, Request::kInvalid, 0); - }}); - } - } - void notify(std::vector signals) { return signal_buffer_.push(std::move(signals)); diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc index 184dbe41d..e779ab30a 100644 --- a/src/turbomind/engine/model_request.cc +++ b/src/turbomind/engine/model_request.cc @@ -29,18 +29,6 @@ void ModelRequest::Cancel() } } -void ModelRequest::End(std::function cb, uint64_t session_id) -{ - auto r = std::make_shared(); - - r->id = r->session.id = session_id; - r->session.kill_flag = true; - - r->end_cb = std::move(cb); - - gateway_->kill(std::move(r)); -} - auto ModelRequest::Forward(InputParam param, std::function cb) -> OutputParam { inputs_ = std::make_shared(); @@ -68,7 +56,7 @@ auto ModelRequest::Forward(InputParam param, std::function cb) -> Output // is used instead const int max_seq_len = session_len_ + 1; const int max_out_len = std::min(output_len, session_len_) + 1; - // This does not include histroy length in interactive mode + // This does not include history length in interactive mode const int max_in_out_len = std::min(input_len + output_len, session_len_) + 1; for (auto& [k, v] : *param.tensors) { @@ -79,13 +67,18 @@ auto ModelRequest::Forward(InputParam param, std::function cb) -> Output add(outputs_, "sequence_length", data_type_v, kCPU, 1); if (param.gen_cfg.output_logits) { - const int len = param.gen_cfg.output_logits == GenerationConfig::kAll ? max_in_out_len : max_out_len; + const int len = + param.gen_cfg.output_logits == GenerationConfig::kAll ? max_in_out_len - param.session.step : max_out_len; add(outputs_, "logits", data_type_, kCPU, len, vocab_size_); + TM_LOG_INFO("[ModelRequest][forward] ID %llu, output_logits len %d", param.session.id, len); } if (param.gen_cfg.output_last_hidden_state) { - const int len = param.gen_cfg.output_last_hidden_state == GenerationConfig::kAll ? max_in_out_len : max_out_len; + const int len = param.gen_cfg.output_last_hidden_state == GenerationConfig::kAll ? 
+ max_in_out_len - param.session.step : + max_out_len; add(outputs_, "last_hidden_state", data_type_, kCPU, len, hidden_dim_); + TM_LOG_INFO("[ModelRequest][forward] ID %llu, output_last_hidden_state len %d", param.session.id, len); } if (param.gen_cfg.output_logprobs) { @@ -105,9 +98,7 @@ auto ModelRequest::Forward(InputParam param, std::function cb) -> Output auto state = std::make_shared(); - if (param.session.start_flag) { - session_id_ = param.session.id; - } + session_id_ = param.session.id; r->id = param.session.id; r->session = param.session; diff --git a/src/turbomind/engine/model_request.h b/src/turbomind/engine/model_request.h index b788c0434..32228e7be 100644 --- a/src/turbomind/engine/model_request.h +++ b/src/turbomind/engine/model_request.h @@ -18,9 +18,6 @@ class ModelRequest { // Cancel running request void Cancel(); - // Reset the channel to uninitailized state, calls `notify` when done - void End(std::function cb, uint64_t session_id); - struct InputParam { std::shared_ptr tensors; diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h index 31276c004..659db679c 100644 --- a/src/turbomind/engine/request.h +++ b/src/turbomind/engine/request.h @@ -81,10 +81,6 @@ struct SessionParam { uint64_t id; int step; - - bool start_flag; - bool end_flag; - bool kill_flag; }; struct RequestState { diff --git a/src/turbomind/engine/request_queue.h b/src/turbomind/engine/request_queue.h index 590578bf8..c62a756d7 100644 --- a/src/turbomind/engine/request_queue.h +++ b/src/turbomind/engine/request_queue.h @@ -28,18 +28,6 @@ class RequestQueue { cv_.notify_one(); } - void kill(std::shared_ptr r) - { - { - std::lock_guard lock{mutex_}; - if (closed_) { - throw std::runtime_error("Queue is clsoed"); - } - kill_.push_back(std::move(r)); - } - cv_.notify_one(); - } - int try_pop(std::vector>& rs, int max_rs_size, int max_count) { std::lock_guard lock{mutex_}; @@ -47,26 +35,17 @@ class RequestQueue { auto it = queue_.begin(); int count{}; while (rs.size() < max_rs_size && count < max_count && it != queue_.end()) { - if (!(*it)->session.start_flag) { - rs.push_back(std::move(*it)); - ++count; - auto tmp = it; - ++it; - queue_.erase(tmp); - } - else { - ++it; - } + rs.push_back(std::move(*it)); + ++count; + auto tmp = it; + ++it; + queue_.erase(tmp); } return count; } - bool pop(std::vector>& infer_reqs, - std::vector>& kill_reqs, - unsigned max_infer, - bool blocking, - bool& abort) + bool pop(std::vector>& infer_reqs, unsigned max_infer, bool blocking, bool& abort) { std::unique_lock lock{mutex_}; @@ -74,9 +53,7 @@ class RequestQueue { if (blocking) { cv_.wait(lock, [this] { - return !(queue_.empty() && kill_.empty()) // - || flag_->load(std::memory_order_relaxed) == expected_ // - || closed_; + return !(queue_.empty()) || flag_->load(std::memory_order_relaxed) == expected_ || closed_; }); if (closed_) { abort = true; @@ -98,9 +75,6 @@ class RequestQueue { queue_.pop_front(); } - kill_reqs.insert(kill_reqs.end(), kill_.begin(), kill_.end()); - kill_.clear(); - return is_first; } @@ -134,8 +108,6 @@ class RequestQueue { std::pmr::list> queue_; std::pmr::unsynchronized_pool_resource pool_; - std::vector> kill_; - std::mutex mutex_; std::condition_variable cv_; diff --git a/src/turbomind/models/llama/BlockTrie.cc b/src/turbomind/models/llama/BlockTrie.cc index 5f87e9828..be5ca8460 100644 --- a/src/turbomind/models/llama/BlockTrie.cc +++ b/src/turbomind/models/llama/BlockTrie.cc @@ -14,16 +14,16 @@ size_t hash(const std::vector& vec) return seed; } 
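The `BlockTrie` rewrite below turns `match`/`cache` into side-effect-free `Match`/`Cache` queries that return the visited path, leaving locking, verification, and eviction to the caller. A small Python sketch of the per-block trie walk, assuming each node caches exactly `block_seq_len` tokens:

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    tokens: tuple = ()
    block_id: int = -1
    block_unique_id: int = 0
    children: dict = field(default_factory=dict)  # hash(tokens) -> Node


def match(root: Node, tokens: list, block_seq_len: int):
    """Walk the trie one full block at a time; stop on a miss or a hash collision."""
    node, block_ids, unique_ids, nodes = root, [], [], []
    for i in range(0, len(tokens) - block_seq_len + 1, block_seq_len):
        chunk = tuple(tokens[i:i + block_seq_len])
        child = node.children.get(hash(chunk))
        if child is None or child.tokens != chunk:
            break
        block_ids.append(child.block_id)
        unique_ids.append(child.block_unique_id)
        nodes.append(child)
        node = child
    # callers verify block_ids against the block manager and trim the invalid tail of `nodes`
    return block_ids, unique_ids, nodes
```

In the C++ version, `Remove` erases the invalid suffix of exactly such a visited path, and `Prune` walks the whole tree with a validity callback.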
-BlockTrie::BlockTrie(size_t block_seq_len, std::shared_ptr block_manager, bool enable_prefix_caching): - block_seq_len_(block_seq_len), block_manager_(block_manager), enable_prefix_caching_(enable_prefix_caching) +BlockTrie::BlockTrie(size_t block_len): block_seq_len_(block_len) { root_ = std::make_shared(); } -void BlockTrie::match(Sequence& seq) +std::tuple>> BlockTrie::Match(const Sequence& seq) const { - BlockIds matched_blocks; - UniqueIds matched_unique_ids; + BlockIds matched_blocks; + UniqueIds matched_unique_ids; + std::vector> matched_nodes; std::shared_ptr curr_node = root_; int num_matched = 0; @@ -34,50 +34,54 @@ void BlockTrie::match(Sequence& seq) size_t hash_key = hash(curr_tokens); auto it = curr_node->children.find(hash_key); - if (it == curr_node->children.end()) { break; } - if (curr_tokens != it->second->tokens) { + TM_LOG_WARNING("hash key cache hit, but tokens are not the same"); break; } - matched_blocks.push_back(it->second->block_id); - matched_unique_ids.push_back(it->second->block_unique_id); + matched_blocks.emplace_back(it->second->block_id); + matched_unique_ids.emplace_back(it->second->block_unique_id); + matched_nodes.emplace_back(it->second); curr_node = it->second; num_matched += block_seq_len_; } - - if (matched_blocks.size() > 0) { - // add use count - block_manager_->Lock(matched_blocks); - block_manager_->Touch(matched_blocks); - // only consider no history blocks - seq.blocks.insert(seq.blocks.end(), matched_blocks.begin(), matched_blocks.end()); - seq.block_unique_ids.insert(seq.block_unique_ids.end(), matched_unique_ids.begin(), matched_unique_ids.end()); - } + return std::make_tuple(matched_blocks, matched_unique_ids, matched_nodes); } -void BlockTrie::cache(const Sequence& seq) +std::tuple>> BlockTrie::Cache(const Sequence& seq, + const std::vector& tokens) { - std::shared_ptr curr_node = root_; - int num_matched = 0; - int idx = 0; - BlockIds cached_blocks; + FT_CHECK(seq.status != Sequence::kCached); + FT_CHECK(tokens.size() <= seq.blocks.size() * block_seq_len_); - while (num_matched + block_seq_len_ <= seq.prompt.size()) { - std::vector curr_tokens(seq.prompt.begin() + num_matched, - seq.prompt.begin() + num_matched + block_seq_len_); - size_t hash_key = hash(curr_tokens); + std::shared_ptr curr_node = root_; + int idx = 0; - auto it = curr_node->children.find(hash_key); + BlockIds cache_block_ids; + UniqueIds cache_block_unique_ids; + std::vector> cache_nodes; + + // We don't cache the last block of the sequence, since it might not be full + // TODO(lvhan): determine whether the last block is full or not.
It is not trivial + // considering chunk prefill + for (int idx = 0; idx < seq.blocks.size() - 1; ++idx) { + auto start = tokens.begin() + idx * block_seq_len_; + auto end = start + block_seq_len_; + + std::vector curr_tokens(start, end); + // TODO(lvhan): add salt to ensure the hash security + size_t hash_key = hash(curr_tokens); int block_id = seq.blocks[idx]; uint64_t block_unique_id = seq.block_unique_ids[idx]; + auto it = curr_node->children.find(hash_key); if (it != curr_node->children.end()) { if (curr_tokens != it->second->tokens) { + TM_LOG_WARNING("[BlockTrie][cache] hash key cache hit, but tokens are not the same"); break; } curr_node = it->second; @@ -91,38 +95,50 @@ void BlockTrie::cache(const Sequence& seq) node->tokens = curr_tokens; node->block_id = block_id; node->block_unique_id = block_unique_id; - node->num_matched = num_matched + block_seq_len_; curr_node->children[hash_key] = node; curr_node = node; } - - cached_blocks.push_back(curr_node->block_id); - num_matched += block_seq_len_; - idx++; + cache_block_ids.emplace_back(block_id); + cache_block_unique_ids.emplace_back(block_unique_id); + cache_nodes.emplace_back(curr_node); } - block_manager_->Touch(cached_blocks); + return std::make_tuple(cache_block_ids, cache_block_unique_ids, cache_nodes); +} + +void BlockTrie::Remove(const std::vector>& nodes, int valid_size) +{ + if (nodes.empty() || valid_size < 1) { + return; + } + // visit and remove nodes in reverse order + for (int idx = nodes.size() - 1; idx >= valid_size; --idx) { + auto child = nodes[idx]; + auto parent = nodes[idx - 1]; + auto it = parent->children.find(child->hash_key); + FT_CHECK(it != parent->children.end()); + FT_CHECK(it->second->tokens == child->tokens); + parent->children.erase(it); + } } -int BlockTrie::verify() +void BlockTrie::Prune(ValidBlockChecker checker) { - return verify_traverse(root_); + return DFSPrune(root_, checker); } -int BlockTrie::verify_traverse(std::shared_ptr& node) +void BlockTrie::DFSPrune(std::shared_ptr& node, ValidBlockChecker checker) { - int valid_count = 1; for (auto it = node->children.begin(); it != node->children.end();) { - if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) { + if (!checker(it->second->block_id, it->second->block_unique_id)) { // child invalid it = node->children.erase(it); } else { - valid_count += verify_traverse(it->second); + DFSPrune(it->second, checker); it++; } } - return valid_count; } } // namespace turbomind diff --git a/src/turbomind/models/llama/BlockTrie.h b/src/turbomind/models/llama/BlockTrie.h index b48c00061..931f0f481 100644 --- a/src/turbomind/models/llama/BlockTrie.h +++ b/src/turbomind/models/llama/BlockTrie.h @@ -17,36 +17,67 @@ struct TrieNode { std::vector tokens; int block_id; uint64_t block_unique_id; - int num_matched; }; class BlockTrie { public: - explicit BlockTrie(size_t block_len_, std::shared_ptr block_manager, bool enable_prefix_caching); + explicit BlockTrie(size_t block_len); - bool enabled() - { - return enable_prefix_caching_; - } + /** + * @brief Attempt to match cached key-value (KV) blocks for a given sequence. + * + * This function iterates the tokens of the sequence and attempts + * to match them with the cached KV blocks. If the max prefix match is found, + * it returns the IDs, unique IDs, and hash keys of the matched blocks. + * + * @param seq The sequence whose tokens are to be matched against the cached KV blocks. + * @return A tuple containing the following: + * - BlockIds: A list of IDs of the matched blocks. 
+ * - UniqueIds: A list of unique IDs of the matched blocks. + * - std::vector>: A list of matched node + * + * @note If no blocks are matched, all containers in the returned tuple will be empty. + */ + std::tuple>> Match(const Sequence& seq) const; - // get cached blocks for sequence - void match(Sequence& seq); + /** + * @brief Cache the key-value (KV) blocks of a given sequence. + * + * This function caches the KV blocks of the specified sequence. Only valid blocks + * of a sequence whose status is NOT `Sequence::kCached` are considered + * for caching. + * + * @param seq The sequence whose KV blocks are to be cached. + * @param tokens The token list that the quence's KV blocks map + * @return A tuple containing the following: + * - BlockIds: A list of IDs of the cached blocks. + * - UniqueIds: A list of unique IDs of the cached blocks. + * - std::vector>: A list of cached node + */ + std::tuple>> Cache(const Sequence& seq, + const std::vector& tokens); - // cache computed blocks for sequence - void cache(const Sequence& seq); + /** + * @brief remove nodes[valid_size:] in a visited path from the trie tree - // remove invalid nodes, return valid count - int verify(); + * @param nodes a visited path returned by `match` or `cache` + * @param valid_size the valid number of cached blocks from the beginning of the path + * @note the visited path must be the returned value from `match` or `cache` + */ + void Remove(const std::vector>& nodes, int valid_size); + + /** + * @brief prune invalid nodes from the tree + */ + using ValidBlockChecker = std::function; + void Prune(ValidBlockChecker checker); private: - int verify_traverse(std::shared_ptr& node); + void DFSPrune(std::shared_ptr& node, ValidBlockChecker checker); private: - bool enable_prefix_caching_; size_t block_seq_len_; - std::shared_ptr block_manager_; - std::shared_ptr root_; }; diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index c16f8350c..535dc76be 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -97,7 +97,7 @@ void DropEmbeddings(const Sequence& seq) seq.input_embedding_ranges.resize(sz); } -void LlamaBatch::DisableInvalidRequests(Requests& infer_reqs, Requests& kill_reqs) +void LlamaBatch::DisableInvalidRequests(Requests& infer_reqs) { NvtxScope _("disable invalid"); @@ -126,10 +126,7 @@ void LlamaBatch::DisableInvalidRequests(Requests& infer_reqs, Requests& kill_req } } - count(kill_reqs); count(infer_reqs); - - validate(kill_reqs, "kill"); validate(infer_reqs, "infer"); // New requests that never get a chance to start @@ -169,25 +166,6 @@ void LlamaBatch::ProcessCancelRequests(std::vector& indices, std::vector& signals) -{ - for (auto& r : kill_reqs) { - if (r) { - int ec = r->ec; - if (!ec) { - if (!sequence_manager_->Erase(r->id)) { - ec = Request::kInvalid; - } - } - signals.push_back([=] { - if (r->end_cb) { - r->end_cb(ec); - } - }); - } - } -} - void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& signals) { NvtxScope scope("infer_request"); @@ -202,7 +180,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& for (const auto& r : reqs) { if (tp_rank_ == 0) { - TM_LOG_INFO("[ProcessInferRequests] Request for %ld received.", (long)r->id); + TM_LOG_INFO("[ProcessInferRequests] Request for %llu received.", r->id); } if (r->ec) { @@ -217,31 +195,12 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& continue; } - auto ptr = r->session.start_flag ? 
sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id); + auto ptr = sequence_manager_->Create(r->id); if (!ptr) { signals.push_back([r] { UpdateState(*r, Request::kInvalid, 0); }); continue; } - const int step = [&] { - int s = r->session.step; - if (s < 0) { - s = ptr->tokens.size(); - } - else if (s > ptr->tokens.size()) { - if (tp_rank_ == 0) { - TM_LOG_WARNING("[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu", s, ptr->id); - } - s = ptr->tokens.size(); - } - return s; - }(); - - if (step + input_length > session_len_) { - signals.push_back([r] { UpdateState(*r, Request::kTooLong, 0); }); - continue; - } - FT_CHECK(!state.requests[idx]); state.requests[idx] = r; @@ -249,13 +208,6 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& auto& seq = *state.sequences[idx]; - if (step < seq.tokens.size()) { - // resize sequence tokens to match step - seq.tokens.resize(step); - seq.cache_len = std::min(seq.cache_len, step); - DropEmbeddings(seq); - } - const int* input_ids = r->inputs.at("input_ids").data(); { @@ -282,10 +234,15 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& } // copy input tokens to prompt for prefix matching - if (input_length && r->session.start_flag && !r->inputs.contains("input_embedding_ranges")) { - // TODO: truncate prompt to enable prefix caching for VLM + if (input_length && !r->inputs.contains("input_embedding_ranges")) { seq.prompt.resize(input_length); std::copy_n(input_ids, input_length, seq.prompt.data()); + seq.prefix_match_end_index = input_length; + if (r->gen_cfg.output_logits || r->gen_cfg.output_last_hidden_state) { + // when output logits or output last hidden state, prefix match can only + // apply to prompts[0:step) + seq.prefix_match_end_index = r->session.step; + } } const int elem_size = byte_size(data_type_); @@ -359,36 +316,26 @@ void LlamaBatch::ProcessInferRequests(const Requests& reqs, std::vector& } // compute rope scaling factor - if (r->session.start_flag) { - seq.rope_theta = model_->attn_param_.rope.base; - if (model_->attn_param_.rope.type == RopeType::kDynamic) { - auto scaling_factor = model_->attn_param_.rope.factor; - if (scaling_factor >= 1.f) { // infer by current context length - auto max_seq_len = state.h_context_length[idx]; - auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings; - if (max_seq_len > max_pos_emb) { - scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1); - float rope_dim = model_->attn_param_.rope.dim; - seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f)); - TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f", - (long)seq.id, - scaling_factor, - seq.rope_theta); - } + seq.rope_theta = model_->attn_param_.rope.base; + if (model_->attn_param_.rope.type == RopeType::kDynamic) { + auto scaling_factor = model_->attn_param_.rope.factor; + if (scaling_factor >= 1.f) { // infer by current context length + auto max_seq_len = state.h_context_length[idx]; + auto max_pos_emb = model_->attn_param_.rope.max_position_embeddings; + if (max_seq_len > max_pos_emb) { + scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1); + float rope_dim = model_->attn_param_.rope.dim; + seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f)); + TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f", + (long)seq.id, + scaling_factor, + seq.rope_theta); } } } state.h_rope_theta[idx] = seq.rope_theta; - if 
(r->session.start_flag) { - // prepare to initialize random state for new sequence - h_random_seed_[idx] = r->gen_cfg.random_seed; - } - else { - // Recover device states if not a new sequence - ((curandState_t*)h_curand_state_.data())[existing_idx.size()] = *(curandState_t*)seq.random_state.data(); - existing_idx.push_back(idx); - } + h_random_seed_[idx] = r->gen_cfg.random_seed; // increment pointer idx++; @@ -987,21 +934,21 @@ void LlamaBatch::OutputLogits(const Tensor& logits, int first, int last, Generat { const auto& src_buf = logits.buffer(); const auto elem_size = byte_size(logits.dtype(), 1); - // when `is_all` is true, logits only contains last token of the sequences + // when `is_all` is false, logits only contains last token of the sequences const bool is_all = out_type == GenerationConfig::kAll; int base = 0; for (int i = first; i < last; ++i) { - const int input_len = h_input_length_buf_[i]; // input lenght for this iter + const int input_len = h_input_length_buf_[i]; // input length for this iter if (state_->requests[i]->gen_cfg.output_logits == out_type) { auto& dst_buf = state_->requests[i]->outputs.at("logits").buffer(); const int cache_len = state_->sequences[i]->cache_len; - const int history_len = state_->sequences[i]->tokens.size(); + const int history_len = state_->requests[i]->session.step; // ----------H------I-------P----------- // C C C C @@ -1011,16 +958,16 @@ void LlamaBatch::OutputLogits(const Tensor& logits, int first, int last, Generat int diff = (history_len + offset) - cache_len; - const int valid_len = input_len - std::max(0, (history_len + offset) - cache_len); + const int valid_len = input_len - std::max(0, diff); - // TM_LOG_ERROR("%d %d %d %d %d %d %d", - // history_len, - // offset, - // cache_len, - // input_len, - // valid_len, - // std::max(0, diff), - // std::max(0, -diff)); + TM_LOG_DEBUG("[output_logits] %d %d %d %d %d %d %d", + history_len, + offset, + cache_len, + input_len, + valid_len, + std::max(0, diff), + std::max(0, -diff)); if (valid_len <= 0) { continue; @@ -1030,10 +977,10 @@ void LlamaBatch::OutputLogits(const Tensor& logits, int first, int last, Generat if (is_all) { // Skip invalid tokens caused by cache miss - src_base += std::max(0, (history_len + offset) - cache_len); + src_base += std::max(0, diff); } // Skip previous chunks - int dst_base = std::max(0, cache_len - (history_len + offset)); + int dst_base = std::max(0, -diff); check_cuda_error(cudaMemcpy2DAsync(dst_buf.raw_data(dst_base * model_->vocab_size_), elem_size * model_->vocab_size_, @@ -1181,7 +1128,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } // Cache computed blocks to block trie - sequence_manager_->CacheIfEnabled(state_->sequences, batch_size); + sequence_manager_->CachePrompt(state_->sequences, batch_size); if (debug_ && tp_rank_ == 0) { for (int i = 0; i < batch_size; ++i) { @@ -1203,9 +1150,7 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) for (int i = 0; i < batch_size - g.partial; ++i) { if (state_->h_finished[i]) { ++g.finished_count; - if (!state_->requests[i]->session.end_flag) { - need_sync = true; - } + need_sync = true; } } if (need_sync) { @@ -1246,14 +1191,10 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) } } -auto LlamaBatch::Interrupt(int index, bool force_stop, bool force_end) -> Signal +auto LlamaBatch::Interrupt(int index, bool force_stop) -> Signal { if (tp_rank_ == 0) { - TM_LOG_INFO("[Interrupt] slot %d, request %lu, stop %d, end %d", - index, - 
(long)state_->requests[index]->id, - force_stop, - force_end); + TM_LOG_INFO("[Interrupt] slot %d, request %llu, stop %d", index, state_->requests[index]->id, force_stop); } if (debug_ && tp_rank_ == 0) { @@ -1267,32 +1208,30 @@ auto LlamaBatch::Interrupt(int index, bool force_stop, bool force_end) -> Signal TM_LOG_INFO("[Interrupt] slot %d, tokens [%s]", index, ss.str().c_str()); } - if (state_->requests[index]->session.end_flag || force_end) { - // Sequence is ending this round or a stop request is issued to end it - FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id)); - } - else { - const int output_len = state_->h_context_length[index]; - auto& seq = *state_->sequences[index]; + const int output_len = state_->h_context_length[index]; + auto& seq = *state_->sequences[index]; - // Update token IDs - seq.tokens.resize(output_len); + // Update token IDs + seq.tokens.resize(output_len); - // output_ids is updated & synced in `Finish` - const auto output_ids = state_->requests[index]->output_ids.data(); - std::copy_n(output_ids, output_len, seq.tokens.data()); + // output_ids is updated & synced in `Finish` + const auto output_ids = state_->requests[index]->output_ids.data(); + std::copy_n(output_ids, output_len, seq.tokens.data()); + // Cache the generated tokens of the sequence + sequence_manager_->CacheGeneration(seq); - // Save random state in host memory - seq.random_state.resize(sizeof(curandState_t)); - // This async copy must be synchronized by the caller - core::Copy((curandState_t*)state_->curand_state.data() + index, 1, (curandState_t*)seq.random_state.data()); + // Save random state in host memory + seq.random_state.resize(sizeof(curandState_t)); + // This async copy must be synchronized by the caller + core::Copy((curandState_t*)state_->curand_state.data() + index, 1, (curandState_t*)seq.random_state.data()); - // Set unlock flag for corresponding blocks, will be unlocked in the next `Materialize()` - sequence_manager_->UpdateAndSetUnlock(seq); - } + // Set unlock flag for corresponding blocks, will be unlocked in the next `Materialize()` + sequence_manager_->UpdateAndSetUnlock(seq); state_->sequences[index] = nullptr; + FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id)); + auto ec = std::exchange(state_->errors[index], Request::kOk); const auto len = *state_->requests[index]->sequence_length.data(); @@ -1306,7 +1245,6 @@ namespace { struct RequestData { std::vector> infer; // incoming inference request - std::vector> kill; // incoming kill request std::vector cancel; // canceled indices in current batch bool abort; @@ -1337,10 +1275,10 @@ void LlamaBatch::InternalThreadEntry() const int free_slot_count = max_batch_size_ - state_->size + g.finished_count; const bool is_empty = (free_slot_count == max_batch_size_); // Block if batch is empty AND no silbings are ready - gateway_->pop(req->infer, req->kill, free_slot_count, is_empty, req->abort, dp_rank_); + gateway_->pop(req->infer, free_slot_count, is_empty, req->abort, dp_rank_); } // Mark reqs to the same session_id as invalid (which are dangerous to the engine) - DisableInvalidRequests(req->infer, req->kill); + DisableInvalidRequests(req->infer); FindCanceledIndices(req->cancel); } @@ -1360,8 +1298,6 @@ void LlamaBatch::InternalThreadEntry() std::vector signals; - ProcessKillRequests(req->kill, signals); - // Shared `priority` field will be assigned by rank-0 ProcessInferRequests(req->infer, signals); diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 
110bb519a..e82e88e7f 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -74,9 +74,7 @@ class LlamaBatch { using Requests = std::vector>; using Signal = std::function; - void DisableInvalidRequests(Requests& infer_reqs, Requests& kill_reqs); - - void ProcessKillRequests(const Requests& reqs, std::vector& signals); + void DisableInvalidRequests(Requests& infer_reqs); void ProcessInferRequests(const Requests& reqs, std::vector& signals); @@ -92,7 +90,7 @@ class LlamaBatch { void Finish(GenerationState& g, std::vector& signals); - [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false); + [[nodiscard]] Signal Interrupt(int index, bool force_stop = false); void ComputeAndOutputLogits(const Tensor& hidden_states, int first, int last); diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc index 623ae3e33..f611c9f6d 100644 --- a/src/turbomind/models/llama/SequenceManager.cc +++ b/src/turbomind/models/llama/SequenceManager.cc @@ -11,6 +11,23 @@ #include namespace turbomind { +template +std::string vector2string(const std::vector& data) +{ + if (data.empty()) { + return "nil"; + } + std::stringstream ss; + + auto it = data.begin(); + ss << *it; + + for (++it; it != data.end(); ++it) { + ss << ", " << *it; + } + return ss.str(); +} + SequenceManager::SequenceManager(size_t layer_num, const BlockConfig& block_config, double block_count, @@ -27,7 +44,10 @@ SequenceManager::SequenceManager(size_t layer_num, size_t block_size = layout.block_size(layer_num); block_manager_ = std::make_shared(block_size, block_count, chunk_size, allocator, get_free_size); - block_trie_ = std::make_shared(block_config.block_len_, block_manager_, enable_prefix_caching); + if (enable_prefix_caching) { + block_trie_ = std::make_shared(block_config.block_len_); + } + TM_LOG_WARNING("[SegMgr] prefix caching is %s", enable_prefix_caching ? 
"enabled" : "disabled"); } const Sequence* SequenceManager::Create(uint64_t id) @@ -36,20 +56,15 @@ const Sequence* SequenceManager::Create(uint64_t id) auto it = sequences_.find(id); if (it != sequences_.end()) { if (rank_ == 0) { - TM_LOG_WARNING("[SequenceManager][Create] Removing conflicting ID %ld", (long)id); + TM_LOG_WARNING("[SeqMgr][Create] Removing conflicting ID %llu", id); } Erase(it); } it = sequences_.emplace_hint(it, id, std::move(sequence)); - return &it->second; -} - -const Sequence* SequenceManager::Get(uint64_t id) -{ - if (auto it = sequences_.find(id); it != sequences_.end()) { - return &it->second; + if (rank_ == 0) { + TM_LOG_INFO("[SeqMgr][Create] ID %llu", id); } - return nullptr; + return &it->second; } bool SequenceManager::Contains(uint64_t id) @@ -68,10 +83,17 @@ void SequenceManager::Erase(std::map::iterator& it) UpdateAndSetUnlock(seq); } // if prefix cache enabled, blocks will be shared by sequences, cannot be freed immediately - if (!block_trie_->enabled()) { + if (!block_trie_) { freed_.insert(freed_.end(), seq.blocks.begin(), seq.blocks.end()); } - it = sequences_.erase(it); + else { + // prune the invalid nodes in the tree + auto is_valid = [this](int block_id, uint64_t block_unique_id) -> bool { + return this->block_manager_->unique_id(block_id) == block_unique_id; + }; + block_trie_->Prune(is_valid); + } + (void)sequences_.erase(it); } bool SequenceManager::Erase(uint64_t id) @@ -83,21 +105,68 @@ bool SequenceManager::Erase(uint64_t id) return false; } -void SequenceManager::CacheIfEnabled(const Sequences& sequences, int active_size) +void SequenceManager::CachePrompt(const Sequences& sequences, int active_size) { - if (block_trie_->enabled()) { - block_trie_->verify(); - for (int i = 0; i < active_size; ++i) { - auto& seq = *sequences[i]; - // only cache prompt blocks - if (!seq.prompt.empty()) { - block_trie_->cache(seq); - seq.prompt.clear(); - } + if (!block_trie_) { + return; + } + for (int i = 0; i < active_size; ++i) { + auto& seq = *sequences[i]; + if (seq.cache_len > seq.prompt.size()) { + // seq prefill finished. 
We don't cache the prompt any longer + seq.prompt.clear(); + continue; + } + BlockIds block_ids; + UniqueIds block_unique_ids; + std::vector> nodes; + std::tie(block_ids, block_unique_ids, nodes) = block_trie_->Cache(seq, seq.prompt); + int valid = block_manager_->Verify(block_ids, block_unique_ids); + if (rank_ == 0) { + TM_LOG_INFO("[SeqMgr][CachePrompt] ID %llu, cached blocks %d, tokens %d, valid blocks %d", + seq.id, + block_ids.size(), + seq.prompt.size(), + valid); + TM_LOG_DEBUG("[SeqMgr][CachePrompt] ID %llu, cached block_ids %s, unique_ids %s", + seq.id, + vector2string(block_ids).c_str(), + vector2string(block_unique_ids).c_str()); + } + // remove invalid nodes from the path in the trie tree if there is any + if (valid < block_ids.size()) { + block_trie_->Remove(nodes, valid); } } } +void SequenceManager::CacheGeneration(const Sequence& seq) +{ + if (!block_trie_) { + return; + } + BlockIds block_ids; + UniqueIds block_unique_ids; + std::vector> nodes; + std::tie(block_ids, block_unique_ids, nodes) = block_trie_->Cache(seq, seq.tokens); + int valid = block_manager_->Verify(block_ids, block_unique_ids); + if (rank_ == 0) { + TM_LOG_INFO("[SeqMgr][CacheGeneration] ID %llu, cached blocks %d, tokens %d, valid blocks %d", + seq.id, + block_ids.size(), + seq.tokens.size(), + valid); + TM_LOG_DEBUG("[SeqMgr][CacheGeneration] ID %llu, cached block_ids %s, unique_ids %s", + seq.id, + vector2string(block_ids).c_str(), + vector2string(block_unique_ids).c_str()); + } + // remove invalid nodes from the path in the trie tree if there is any + if (valid < block_ids.size()) { + block_trie_->Remove(nodes, valid); + } +} + void SequenceManager::VerifyAndLockCached(const Sequences& sequences) { BlockIds blocks; @@ -223,8 +292,6 @@ struct Transaction { const Sequences& sequences_; Schedule& schedule_; - std::shared_ptr block_trie_; - explicit Transaction(const Sequences& sequences, int index, int block_count, int input_count, Schedule& sched): sequences_(sequences), schedule_(sched), index_(index), block_count_(block_count), input_count_(input_count) { @@ -320,25 +387,6 @@ void SequenceManager::SortByPriority(Sequences& sequences, context_lengths.swap(tmp_lengths); } -// template -// void SortByPriority(const std::vector
@@ -223,8 +292,6 @@ struct Transaction {
     const Sequences& sequences_;
     Schedule&        schedule_;
 
-    std::shared_ptr<BlockTrie> block_trie_;
-
     explicit Transaction(const Sequences& sequences, int index, int block_count, int input_count, Schedule& sched):
         sequences_(sequences), schedule_(sched), index_(index), block_count_(block_count), input_count_(input_count)
     {
@@ -320,25 +387,6 @@ void SequenceManager::SortByPriority(Sequences& sequences,
     context_lengths.swap(tmp_lengths);
 }
 
-// template<class P, class... Ts>
-// void SortByPriority(const std::vector<P>& priorities, Ts&... ranges)
-// {
-//     // sort according to priority
-//     std::vector<int> idxs(priorities.size());
-//     std::iota(idxs.begin(), idxs.end(), 0);
-//     std::sort(idxs.begin(), idxs.end(), [&](int i, int j) {
-//         return priorities[i] < priorities[j];  //
-//     });
-//     auto reorder = [&](auto& src) {
-//         auto dst = src;
-//         for (size_t i = 0; i < idxs.size(); ++i) {
-//             dst[i] = src[idxs[i]];
-//         }
-//         src.swap(dst);
-//     };
-//     (reorder(ranges), ...);
-// }
-
 std::vector<int> SequenceManager::CountRequiredBlocks(const Sequences&        sequences,
                                                       const std::vector<int>& context_lengths,
                                                       int                     step_length)
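The commented-out helper deleted above sketched the index-permutation idiom for sorting several parallel arrays by one priority key. For reference, a standalone compilable version of that idiom is shown below; the names are invented and this is not turbomind code. Every range passed in must have the same length as priorities.

// Sketch: sort several parallel ranges by one priority vector via an index permutation.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

template<class P, class... Ts>
void SortParallelByPriority(const std::vector<P>& priorities, Ts&... ranges)
{
    std::vector<int> idxs(priorities.size());
    std::iota(idxs.begin(), idxs.end(), 0);
    std::sort(idxs.begin(), idxs.end(), [&](int i, int j) { return priorities[i] < priorities[j]; });
    auto reorder = [&](auto& src) {
        auto dst = src;
        for (std::size_t i = 0; i < idxs.size(); ++i) {
            dst[i] = src[idxs[i]];  // place the element ranked i-th by priority at position i
        }
        src.swap(dst);
    };
    (reorder(ranges), ...);  // C++17 fold expression: apply the same permutation to every range
}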
@@ -371,6 +419,63 @@ void SequenceManager::AssignAndActivate(const Sequences& sequences,  //
     }
 }
 
+void SequenceManager::PrefixMatch(Sequences& sequences)
+{
+    if (!block_trie_) {
+        return;
+    }
+
+    for (int i = 0; i < sequences.size(); i++) {
+        BlockIds  block_ids;
+        UniqueIds unique_ids;
+        std::vector<std::shared_ptr<TrieNode>> matched_nodes;
+        auto& seq = const_cast<Sequence&>(*sequences[i]);
+        if (seq.cache_len != 0) {
+            // We only apply prefix-cache matching when seq.cache_len is 0,
+            // which means this seq is a brand-new sequence.
+            // seq.cache_len is updated after every forward iter. Refer to `LlamaBatch::Forward`
+            continue;
+        }
+        if (seq.prefix_match_end_index < block_seq_len_) {
+            continue;
+        }
+        std::tie(block_ids, unique_ids, matched_nodes) = block_trie_->Match(seq);
+
+        int valid = block_manager_->Verify(block_ids, unique_ids);
+        // remove invalid nodes from the path in the trie, if there are any
+        if (valid < block_ids.size()) {
+            block_trie_->Remove(matched_nodes, valid);
+        }
+        valid = std::min(valid, seq.prefix_match_end_index / block_seq_len_);
+
+        BlockIds matched_ids(block_ids.begin(), block_ids.begin() + valid);
+        block_manager_->Lock(matched_ids);
+        // block_manager_->Touch(matched_ids);
+        if (rank_ == 0) {
+            TM_LOG_INFO("[SeqMgr][match] ID %llu, hit blocks %d, cache_len %d", seq.id, valid, seq.cache_len);
+            TM_LOG_DEBUG("[SeqMgr][match] ID %llu, hit block_ids %s, unique_ids %s",
+                         seq.id,
+                         vector2string(block_ids).c_str(),
+                         vector2string(unique_ids).c_str());
+        }
+
+        FT_CHECK(seq.blocks.empty());
+        seq.cache_len = valid * block_seq_len_;
+        seq.blocks.insert(seq.blocks.end(), block_ids.begin(), block_ids.begin() + valid);
+        seq.block_unique_ids.insert(seq.block_unique_ids.end(), unique_ids.begin(), unique_ids.begin() + valid);
+        if (rank_ == 0) {
+            TM_LOG_INFO("[SeqMgr][match] ID %llu, after matching, blocks %d, cache_len %d",
+                        seq.id,
+                        seq.blocks.size(),
+                        seq.cache_len);
+            TM_LOG_DEBUG("[SeqMgr][match] ID %llu, after matching, block_ids %s, unique_ids %s",
+                         seq.id,
+                         vector2string(seq.blocks).c_str(),
+                         vector2string(seq.block_unique_ids).c_str());
+        }
+    }
+}
+
 auto SequenceManager::Materialize(Sequences                    sequences,
                                   std::vector<int>             context_lengths,
                                   const std::vector<uint64_t>& priorities,
@@ -391,19 +496,7 @@ auto SequenceManager::Materialize(Sequences sequences,
     // the blocks can still be preempted later
     VerifyAndLockCached(sequences);
 
-    if (block_trie_->enabled()) {
-        // verify blocks in trie cache
-        block_trie_->verify();
-
-        // match prefix cache
-        for (int i = 0; i < sequences.size(); i++) {
-            if (!sequences[i]->prompt.empty() && sequences[i]->blocks.empty()) {
-                auto& seq = const_cast<Sequence&>(*sequences[i]);
-                block_trie_->match(seq);
-                seq.cache_len = seq.blocks.size() * block_seq_len_;
-            }
-        }
-    }
+    PrefixMatch(sequences);
 
     const int max_input_count = adjust(sequences, context_lengths);
 
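The block arithmetic in PrefixMatch is worth a concrete check: matched blocks are first capped to those that passed verification, then to those lying entirely below prefix_match_end_index, and cache_len is the resulting whole-block token count. The tiny worked example below uses assumed values (block_seq_len_ = 64, prefix_match_end_index = 200) and plain integers rather than turbomind types.

// Worked example of the PrefixMatch block arithmetic (assumed values).
#include <algorithm>
#include <cassert>

int main()
{
    const int block_seq_len          = 64;   // assumed tokens per KV block
    const int prefix_match_end_index = 200;  // prefix matching allowed for tokens [0, 200)
    int       valid                  = 4;    // blocks whose unique ids passed verification

    // Only blocks that lie entirely inside the matchable prefix count: 200 / 64 = 3.
    valid = std::min(valid, prefix_match_end_index / block_seq_len);
    const int cache_len = valid * block_seq_len;  // tokens that will be served from cache

    assert(valid == 3);
    assert(cache_len == 192);
    return 0;
}

So even if the trie matched four blocks, only three are reused and the first 192 tokens of the new sequence are served from the cache.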
diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h
index 3e17ff355..306b7c886 100644
--- a/src/turbomind/models/llama/SequenceManager.h
+++ b/src/turbomind/models/llama/SequenceManager.h
@@ -26,14 +26,16 @@ struct Sequence {
     BlockIds  blocks;
     UniqueIds block_unique_ids;
 
-    int input_length = 0;
+    int input_length = 0;  // the number of tokens to be processed in each forward iter
 
     mutable std::vector<int> prompt;
-
-    mutable std::vector<int> tokens;  // update by user
+    mutable std::vector<int> tokens;  // update by user or when the sequence finishes
 
     mutable int cache_len = 0;
 
+    // the token index from which prefix matching no longer applies
+    mutable int prefix_match_end_index = 0;
+
     // additional data kept round-to-round
     mutable std::vector random_state;  // update by user
@@ -89,8 +91,6 @@ class SequenceManager {
 
     [[nodiscard]] const Sequence* Create(uint64_t id);
 
-    [[nodiscard]] const Sequence* Get(uint64_t id);
-
     [[nodiscard]] bool Contains(uint64_t id);
 
     [[nodiscard]] bool Erase(uint64_t id);
@@ -110,8 +110,22 @@ class SequenceManager {
                      const std::vector<uint64_t>& priorities,
                      int                          step_length,
                      AdjustInputCount             adjust);
-
-    void CacheIfEnabled(const Sequences& sequences, int active_size);
+
+    /** @brief cache the input prompt tokens of each seq in sequences[0:active_size-1]
+     *
+     * @param sequences   the sequence list
+     * @param active_size the number of active sequences in the list
+     */
+    void CachePrompt(const Sequences& sequences, int active_size);
+
+    /** @brief cache the generated tokens of a given sequence
+     *
+     * @param sequence the given sequence
+     *
+     * @note This function can only be called after the sequence finishes generation
+     * and all tokens, including the prompt and the generated tokens, have been put into
+     * `seq.tokens`
+     */
+    void CacheGeneration(const Sequence& sequence);
 
     [[nodiscard]] void* GetBlockPtr(int block_id)
     {
@@ -143,6 +157,8 @@
                          const BlockIds&  blocks,
                          const UniqueIds& unique_ids);
 
+    void PrefixMatch(Sequences& sequences);
+
 private:
     int block_seq_len_;
     int rank_;
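Read together, the declarations above imply a call order: PrefixMatch when sequences are materialized, CachePrompt after prefill iterations, and CacheGeneration only once a finished sequence has its full token stream in seq.tokens. The sketch below only illustrates that ordering; PrefixCacheHooks, StepOnce, and their signatures are invented stand-ins, not the real SequenceManager API.

// Illustration of the call order implied by the header changes above (assumed interface).
#include <cstdint>
#include <vector>

struct PrefixCacheHooks {
    virtual ~PrefixCacheHooks() = default;
    // match cached blocks for brand-new sequences (cache_len == 0) at materialization time
    virtual void PrefixMatch(std::vector<uint64_t>& active_ids) = 0;
    // cache prompt blocks while prefill is still in progress
    virtual void CachePrompt(const std::vector<uint64_t>& active_ids) = 0;
    // cache the whole token stream once a sequence has finished generating
    virtual void CacheGeneration(uint64_t finished_id) = 0;
};

// One scheduling step: match first, cache prompts during prefill, and cache generation
// only for sequences whose `tokens` already hold prompt + generated tokens.
inline void StepOnce(PrefixCacheHooks&            mgr,
                     std::vector<uint64_t>&       active,
                     bool                         prefill_running,
                     const std::vector<uint64_t>& finished)
{
    mgr.PrefixMatch(active);
    if (prefill_running) {
        mgr.CachePrompt(active);
    }
    for (uint64_t id : finished) {
        mgr.CacheGeneration(id);
    }
}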
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index a25daab2f..2def8c48d 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -291,25 +291,16 @@ struct ScopedGIL {
 PYBIND11_MODULE(_turbomind, m)
 {
     py::class_<ft::SessionParam>(m, "SessionParam")
-        .def(py::init([](uint64_t id, int step, bool start, bool end) {
-                 if (!start && end) {
-                     throw std::logic_error("unsupported arguments: start=false, end=true");
-                 }
+        .def(py::init([](uint64_t id, int step) {
                  ft::SessionParam param{};
-                 param.id         = id;
-                 param.step       = step;
-                 param.start_flag = start;
-                 param.end_flag   = end;
+                 param.id   = id;
+                 param.step = step;
                  return param;
             }),
             "id"_a,
-            "step"_a,
-            "start"_a,
-            "end"_a)
+            "step"_a)
        .def_readwrite("id", &ft::SessionParam::id)
-       .def_readwrite("step", &ft::SessionParam::step)
-       .def_readwrite("start", &ft::SessionParam::start_flag)
-       .def_readwrite("end", &ft::SessionParam::end_flag);
+       .def_readwrite("step", &ft::SessionParam::step);
 
     py::class_<ft::GenerationConfig>(m, "GenerationConfig")
         .def(py::init<>())
@@ -459,15 +450,7 @@
         [](ModelRequest* model_request) {
             model_request->Cancel();  //
         },
-        py::call_guard<ScopedGIL>())
-        .def(
-            "end",
-            [](ModelRequest* model_request, std::function cb, uint64_t session_id) {
-                model_request->End(std::move(cb), session_id);  //
-            },
-            py::call_guard<ScopedGIL>(),
-            "cb"_a,
-            "session_id"_a);
+        py::call_guard<ScopedGIL>());
 
     // transformer model
     using ft::LlamaTritonModel;
diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py
index f14ab8249..f8b1d65c0 100644
--- a/tests/pytorch/paging/test_scheduler.py
+++ b/tests/pytorch/paging/test_scheduler.py
@@ -25,7 +25,8 @@ def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks):
         yield CacheConfig(max_batches=256,
                           block_size=block_size,
                           num_cpu_blocks=num_cpu_blocks,
-                          num_gpu_blocks=num_gpu_blocks)
+                          num_gpu_blocks=num_gpu_blocks,
+                          enable_prefix_caching=False)
 
     @pytest.fixture
     def scheduler_config(self):