 import numpy as np
 import pytest
 import torch
-from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

 import colossalai
-from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

@@ -28,69 +28,37 @@ def generate_inputs(num_sequences, min_length, max_length):
     return sequences


-@parameterize(
-    "test_config",
-    [
-        {
-            "max_batch_size": 8,
-            "max_output_len": 512,
-            "max_input_len": 64,
-            "do_sample": False,
-        }
-    ],
-)
-def check_inference_engine(test_config, use_engine=False, prompt_template=None):
+@parameterize("n_multiple", [10])
+@parameterize("max_batch_size", [8])
+@parameterize("max_input_len", [128])
+@parameterize("max_output_len", [128])
+def check_inference_engine(n_multiple, max_batch_size, max_input_len, max_output_len):
     setup_seed(20)
-    max_batch_size = test_config["max_batch_size"]
-    max_input_len = test_config["max_input_len"]
-    max_output_len = test_config["max_output_len"]
-    do_sample = test_config["do_sample"]
-    top_p = 0.5
-    top_k = 50
-    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
+
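+    # Use the reference Llama tokenizer and a randomly initialized 2-layer Llama model to keep the test lightweight.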
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).cuda()
     model = model.eval()

-    inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
-
-    if use_engine:
-        inference_config = InferenceConfig(
-            max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
-        )
-        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
-        assert inference_engine.generation_config.max_new_tokens == max_output_len
-        inference_engine.add_request(prompts_token_ids=inputs_token_ids)
-        assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
-        outputs = inference_engine.generate(generation_config=generation_config)
-    else:
-        if prompt_template:
-            # apply prompt template
-            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
-        inputs = inputs.cuda()
-        generation_config = GenerationConfig(
-            do_sample=do_sample,
-            top_p=top_p,
-            top_k=top_k,
-            pad_token_id=tokenizer.pad_token_id,
-            max_new_tokens=max_output_len,
-        )
-        outputs = model.generate(inputs, generation_config=generation_config)
-        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    assert len(outputs) == 10 * max_batch_size
-
-
-@parameterize("prompt_template", [None, "llama"])
-def check_continuous_batching(prompt_template):
-    check_inference_engine(use_engine=True, prompt_template=prompt_template)
+    inputs_token_ids = generate_inputs(
+        n_multiple * max_batch_size, min_length=max_input_len // 2, max_length=max_input_len
+    )
+    inference_config = InferenceConfig(
+        max_batch_size=max_batch_size, max_input_len=max_input_len, max_output_len=max_output_len
+    )
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+    assert inference_engine.generation_config.max_new_tokens == max_output_len
+
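+    # Enqueue many more prompts than one batch can hold so the engine must schedule them continuously.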
+    inference_engine.add_request(prompts_token_ids=inputs_token_ids)
+    assert inference_engine.request_handler._has_waiting()
+
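+    # generate() should drain the waiting queue and return one output per submitted prompt.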
+    outputs = inference_engine.generate()
+    assert not inference_engine.request_handler._has_waiting()
+    assert len(outputs) == n_multiple * max_batch_size


 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
-    check_continuous_batching()
+    check_inference_engine()


 @pytest.mark.dist