[Refactor] Support multi-session chat (#178)
* add some dist utils
* add model utils
* add termio and basicstreamer
* typo
* fix world size
* refactor chat and tested llama1
* add internlm adapter and support stopping criteria
* concat with id for internlm
* update docstring
* update and support llama2
* typo
* move docs to docs
* update docstring of session manager
* update docstring
* update docs
* fix accel none in model
* fix and add test for tensor broadcast
* fix session using typing to check type
* add docstrings and comprehensive condition test
* unit test for dist
* fix session
* split unittests of utils
* typo
* update control flow of accel
* move test model
* remove main in unittest
* remove some log
* remove some comments
1 parent: c80f3e4 · commit: 4bd0b48
14 changed files with 1,081 additions and 262 deletions.
@@ -0,0 +1,74 @@
# Pytorch

## Chat in command line

LMDeploy supports chatting with PyTorch models via the submodule `lmdeploy.pytorch.chat`.

This submodule allows users to chat with a language model through the command line, and optionally accelerate the model using backends such as DeepSpeed.

**Example 1**: Chat with default settings

```shell
python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL
```

**Example 2**: Disable sampling and chat history

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --temperature 0 --max-history 0
```

**Example 3**: Accelerate with DeepSpeed inference

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

Note: to use DeepSpeed, you need to install it first. To accelerate InternLM, you need a customized version: <https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0>

**Example 4**: Run the model with tensor parallelism on 2 GPUs

```shell
deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

This module also supports the following control commands to change generation behavior during a chat:

- `exit`: terminate and exit the chat
- `config set key=value`: change the generation config `key` to `value`, e.g. `config set temperature=0` disables sampling for subsequent chats
- `clear`: clear the chat history

### Simple diagram of components

```mermaid
graph LR;
    subgraph model specific adapter
        p((user_input))-->tokenize-->id((input_ids))-->decorate
        tmpl_ids((template_ids))-->decorate;
    end
    subgraph generate
        model[CausalLM_model.generate]-->gen_result(("gen_result"))
        gen_result-->hid
        gen_result-->attn((attention))
    end
    subgraph streamer
        model-->s[streamer]--value-->decode_single--token-->output
    end
    subgraph session_manager
        prepend_history-->fullid((complete_ids));
        trim-->prepend_history
    end
    decorate-->prepend_history
    hid((history_ids))-->trim;
    attn-->trim;
    fullid-->model
    tokenizer((tokenizer))-->decode_single
    tokenizer-->tokenize
    p-->genconfig(GenConfig)-->model
```
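
To make the diagram concrete, here is a minimal sketch of one chat turn using plain Hugging Face APIs. It is illustrative only: the model path is a placeholder, the function name `chat_once` is made up for this sketch, and the real implementation adds streaming, history trimming, and model-specific templates on top of this.

```python
# Minimal sketch of one chat turn (illustrative; the model path is a placeholder).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

path = 'huggyllama/llama-7b'  # placeholder HF checkpoint
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)

# "start ids" of the session, kept and extended across turns
history_ids = torch.tensor([[tokenizer.bos_token_id]])


def chat_once(prompt: str) -> str:
    global history_ids
    # adapter stage: tokenize the user input
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt')
    # session manager stage: prepend history to build the complete ids
    full_ids = torch.cat([history_ids, input_ids], dim=1).to(model.device)
    # generate stage: produce new tokens under a generation config
    gen_config = GenerationConfig(max_new_tokens=128, do_sample=False)
    output_ids = model.generate(full_ids, generation_config=gen_config)
    # session manager stage: keep the full sequence as the new history
    history_ids = output_ids.cpu()
    # streamer stage omitted; decode the newly generated tokens all at once
    return tokenizer.decode(output_ids[0, full_ids.shape[1]:], skip_special_tokens=True)


print(chat_once('Hello!'))
```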
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.

import logging

import torch.nn as nn

from .base import BasicAdapter, BasicAdapterFast
from .internlm import InternLMAdapter
from .llama2 import Llama2Adapter

logger = logging.getLogger(__name__)


def _get_default_adapter(tokenizer):
    if tokenizer.is_fast:
        return BasicAdapterFast
    else:
        return BasicAdapter


def init_adapter(model: nn.Module, tokenizer, adapter=None):
    if adapter is None:
        for v in model.modules():
            if 'InternLMModel' in v.__class__.__name__:
                Adapter = InternLMAdapter
                break
            elif 'LlamaModel' in v.__class__.__name__:
                Adapter = Llama2Adapter
                break
        else:
            # for-else: no known architecture matched, fall back to default
            Adapter = _get_default_adapter(tokenizer)
    elif adapter == 'llama1':
        Adapter = _get_default_adapter(tokenizer)
    else:
        raise ValueError(f'Adapter {adapter} is not allowed.')

    logger.info(f'Using adapter {Adapter.__name__}')

    return Adapter(tokenizer)
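
For orientation, a minimal usage sketch of `init_adapter` follows. It is an assumption-laden illustration: the checkpoint path is a placeholder, the import of `init_adapter` is omitted because file paths are not shown in this diff, and `trust_remote_code=True` is only needed for models such as InternLM that ship custom modeling code.

```python
# Illustrative usage of init_adapter defined above; the path is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

path = '/path/to/internlm-chat-7b'  # placeholder HF checkpoint
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

# With adapter=None, the adapter is auto-detected from the model's module
# class names; passing adapter='llama1' forces the basic (default) adapter.
adapter = init_adapter(model, tokenizer)
input_ids = adapter.encode_and_decorate('Hello!')
```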
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Basic adapter suitable for general HuggingFace models."""

import logging
import re

from transformers import (PreTrainedTokenizer, PreTrainedTokenizerBase,
                          PreTrainedTokenizerFast)

logger = logging.getLogger(__name__)


class BaseAdapter:
    """Base class for all adapters.
    Note:
        Adapters coordinate with the session manager to prepare input_ids.
        The full sequence fed to the model is as follows:
        ```
        adapter.start_ids
        adapter.encode_and_decorate(user_input_1)
        output_1_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_2)
        output_2_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_3)
        ```
        Thus the adapter is responsible for providing the model-specific
        ``start_ids``, ``sep_ids``, and a method to encode a single prompt.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Model specific method to encode and decorate prompt."""
        raise NotImplementedError

    def decode(self, value):
        """Model specific method to decode single value to string."""
        raise NotImplementedError

    @property
    def stopping_criteria(self):
        """Model specific stopping criteria for generation."""
        return None

    @property
    def start_ids(self):
        """Model specific start ids."""
        return [self.tokenizer.bos_token_id]

    @property
    def sep_ids(self):
        """Model specific separation ids."""
        return [self.tokenizer.bos_token_id]


class BasicAdapter(BaseAdapter):
    """Basic adapter for slow tokenizers."""

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.
        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Fallback when tokenizer is not fast."""

        self.tokenizer: PreTrainedTokenizer

        tok = self.tokenizer.decode(value)
        return tok + ' '

class BasicAdapterFast(BaseAdapter):
    """Basic adapter for fast tokenizers."""

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.
        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Decode with fast tokenizers."""

        self.tokenizer: PreTrainedTokenizerFast

        tok = self.tokenizer._convert_id_to_token(value)
        if tok.startswith('▁'):  # sentencepiece
            space = ' '
            tok = tok[1:]
        else:
            space = ''
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '\r':
            tok = '\n'

        tok = space + tok

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok
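
The snippet below is a rough sketch of how the pieces described in the `BaseAdapter` docstring could be assembled into one session sequence. The tokenizer path is a placeholder and the "model output" ids are stand-in values, not real generations.

```python
# Sketch only: assembling full input_ids from adapter pieces (placeholder path).
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')  # placeholder
adapter = BasicAdapterFast(tokenizer)

turn_1 = adapter.encode_and_decorate('Hello!')
output_1 = torch.tensor([[4966, 29991]])  # stand-in ids for a model reply
turn_2 = adapter.encode_and_decorate('How are you?')

full_ids = torch.cat([
    torch.tensor([adapter.start_ids]),  # <bos>
    turn_1,
    output_1,
    torch.tensor([adapter.sep_ids]),    # separator between turns
    turn_2,
], dim=1)
```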
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import re

import torch
from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
                          StoppingCriteriaList)

from .base import BaseAdapter

logger = logging.getLogger(__name__)


class InternLMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for HF version of InternLM."""

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        return input_ids[0, -1] in [2, 103028]

class InternLMAdapter(BaseAdapter):
    r"""Adapter for InternLM.
    InternLM uses the following template, where '\n' should be token id 13:
        <bos> (no actual newline here, just for better readability)
        <|User|>:{prompt}<eoh>\n
        <|Bot|>:{model_output}<eoa>\n
        <|User|>:{prompt}<eoh>\n
        <|Bot|>:{model_output}<eoa>\n
        ...
        <eos>
    """

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
    # ids of '<|User|>:'
    B_USER_ID = torch.tensor([[333, 352, 1621, 352, 27232]])
    # ids of '<eoh>\n<|Bot|>:'
    E_USER_ID = torch.tensor([[103027, 13, 333, 352, 23845, 352, 27232]])
    # ids of '<bos>'
    start_ids = [1]
    # ids of '\n'
    sep_ids = [13]

    def __init__(self, tokenizer: PreTrainedTokenizerFast):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt):
        r"""Encode prompt and decorate with template.
        Note:
            We leave <bos> and chat history for the session manager to add,
            so we decorate input_ids to '<|User|>:{prompt}<eoh>\n<|Bot|>:'.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors='pt',
        )
        # This is f'<|User|>:{prompt}<eoh>\n<|Bot|>:'
        # but force \n to 13 instead of 364
        input_ids = torch.cat([self.B_USER_ID, input_ids, self.E_USER_ID],
                              dim=1)
        return input_ids

    def decode(self, value):
        """Decode generated tokens for InternLM."""

        tok = self.tokenizer.decode(value)
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '<eoa>' or tok == '\r':
            tok = '\n'

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok

    @property
    def stopping_criteria(self):
        return StoppingCriteriaList([InternLMStoppingCriteria()])
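
A hedged sketch of how this adapter could be combined with `generate` follows; the model path is a placeholder, and the real chat loop (history management, streaming) lives elsewhere in this commit.

```python
# Illustrative only; model/tokenizer come from a placeholder InternLM checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = '/path/to/internlm-chat-7b'  # placeholder
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

adapter = InternLMAdapter(tokenizer)
# '<bos>' + '<|User|>:{prompt}<eoh>\n<|Bot|>:'
input_ids = torch.cat(
    [torch.tensor([adapter.start_ids]),
     adapter.encode_and_decorate('Hi!')], dim=1)
output_ids = model.generate(
    input_ids.to(model.device),
    max_new_tokens=128,
    stopping_criteria=adapter.stopping_criteria,  # stops on InternLM end tokens
)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:]))
```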