[Refactor] Support multi-session chat (#178)
* add some dist utils

* add model utils

* add termio and basicstreamer

* typo

* fix world size

* refactor chat and tested llama1

* add internlm adapter and support stopping criteria

* concat with id for internlm

* update docstring

* update and support llama2

* typo

* move docs to docs

* update docstring of session manager

* update docstring

* update docs

* fix accel none in model

* fix and add test for tensor broadcast

* fix session using typing to check type

* add docstrings and comprehensive condition test

* unit test for dist

* fix session

* split unittests of utils

* typo

* update control flow of accel

* move test model

* remove main in unittest

* remove some log

* remove some comments
wangruohui authored Aug 7, 2023
1 parent c80f3e4 commit 4bd0b48
Showing 14 changed files with 1,081 additions and 262 deletions.
12 changes: 7 additions & 5 deletions README.md
@@ -122,11 +122,7 @@ For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and

### Inference with PyTorch

You have to install deepspeed first before running with PyTorch.

```
pip install deepspeed
```
For detailed instructions on running inference with PyTorch models, see [here](docs/en/pytorch.md).

#### Single GPU

@@ -149,6 +145,12 @@ deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
--seed 0
```

You need to install deepspeed first to use this feature.

```
pip install deepspeed
```

## Quantization

In fp16 mode, kv_cache int8 quantization can be enabled, and a single card can serve more users.
74 changes: 74 additions & 0 deletions docs/en/pytorch.md
@@ -0,0 +1,74 @@
# PyTorch

## Chat in command line

LMDeploy supports chatting with PyTorch models via the submodule `lmdeploy.pytorch.chat`.

This submodule allows users to chat with a language model through the command line, and optionally to accelerate the model using backends like DeepSpeed.

**Example 1**: Chat with default settings

```shell
python -m lmdeploy.pytorch.chat $PATH_TO_HF_MODEL
```

**Example 2**: Disable sampling and chat history

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --temperature 0 --max-history 0
```

**Example 3**: Accelerate with deepspeed inference

```shell
python -m lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

Note: to use DeepSpeed, you need to install it first. To accelerate InternLM with DeepSpeed, you need a customized version: <https://github.com/wangruohui/DeepSpeed/tree/support_internlm_0.10.0>

**Example 4**: Run the model on 2 GPUs with tensor parallelism

```shell
deepspeed --module --num_gpus 2 lmdeploy.pytorch.chat \
    $PATH_TO_LLAMA_MODEL_IN_HF_FORMAT \
    --accel deepspeed
```

This module also allows the following control commands to change generation behavior during chat.

- `exit`: terminate and exit the chat
- `config set key=value`: change the generation config `key` to `value`, e.g. `config set temperature=0` disables sampling for the following chats
- `clear`: clear the chat history

### Simple diagram of components

```mermaid
graph LR;
subgraph model specific adapter
p((user_input))-->tokenize-->id((input_ids))-->decorate
tmpl_ids((template_ids))-->decorate;
end
subgraph generate
model[CausalLM_model.generate]-->gen_result(("gen_result"))
gen_result-->hid
gen_result-->attn((attention))
end
subgraph streamer
model-->s[streamer]--value-->decode_single--token-->output
end
subgraph session_manager
prepend_history-->fullid((complete_ids));
trim-->prepend_history
end
decorate-->prepend_history
hid((history_ids))-->trim;
attn-->trim;
fullid-->model
tokenizer((tokenizer))-->decode_single
tokenizer-->tokenize
p-->genconfig(GenConfig)-->model
```
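
Putting the diagram into code, one chat turn looks roughly like the sketch below. It is only an illustration of how the pieces fit together: `prepend_history` and the other session-manager and streamer names are assumptions drawn from the diagram, not the exact lmdeploy API.

```python
def chat_one_turn(adapter, session_manager, model, streamer, gen_config,
                  user_input: str):
    """Run one chat turn following the diagram above (illustrative only)."""
    # Model-specific adapter: tokenize the prompt and wrap it with the
    # model's chat template.
    input_ids = adapter.encode_and_decorate(user_input)

    # Session manager: trim overly long history, then prepend it so the
    # model sees the whole conversation so far (method name is assumed).
    complete_ids = session_manager.prepend_history(input_ids)

    # Generate; the streamer decodes values one by one and prints them.
    output = model.generate(complete_ids,
                            generation_config=gen_config,
                            streamer=streamer,
                            stopping_criteria=adapter.stopping_criteria)
    return output
```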
39 changes: 39 additions & 0 deletions lmdeploy/pytorch/adapters/__init__.py
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.

import logging

import torch.nn as nn

from .base import BasicAdapter, BasicAdapterFast
from .internlm import InternLMAdapter
from .llama2 import Llama2Adapter

logger = logging.getLogger(__name__)


def _get_default_adapter(tokenizer):
    if tokenizer.is_fast:
        return BasicAdapterFast
    else:
        return BasicAdapter


def init_adapter(model: nn.Module, tokenizer, adapter=None):
    if adapter is None:
        for v in model.modules():
            if 'InternLMModel' in v.__class__.__name__:
                Adapter = InternLMAdapter
                break
            elif 'LlamaModel' in v.__class__.__name__:
                Adapter = Llama2Adapter
                break
        else:  # for-else: no known model class found, fall back to default
            Adapter = _get_default_adapter(tokenizer)
    elif adapter == 'llama1':
        Adapter = _get_default_adapter(tokenizer)
    else:
        raise ValueError(f'Adapter {adapter} is not allowed.')

    logger.info(f'Using adapter {Adapter.__name__}')

    return Adapter(tokenizer)
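
As a rough usage sketch (the model path is a placeholder and loading details may differ from lmdeploy's actual chat entry point), `init_adapter` pairs with a HuggingFace model and tokenizer like this:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.pytorch.adapters import init_adapter

model_path = 'path/to/llama-7b-hf'  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# Chooses InternLMAdapter or Llama2Adapter from the model's module classes,
# otherwise falls back to a basic adapter selected by tokenizer type.
adapter = init_adapter(model, tokenizer)
input_ids = adapter.encode_and_decorate('Hello!')
```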
127 changes: 127 additions & 0 deletions lmdeploy/pytorch/adapters/base.py
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Basic adapter suitable for general HuggingFace models."""

import logging
import re

from transformers import (PreTrainedTokenizer, PreTrainedTokenizerBase,
                          PreTrainedTokenizerFast)

logger = logging.getLogger(__name__)


class BaseAdapter:
    """Base class for all adapters.

    Note:
        Adapters coordinate with the session manager to prepare input_ids.
        The full sequence fed to the model is as follows:

        ```
        adapter.start_ids
        adapter.encode_and_decorate(user_input_1)
        output_1_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_2)
        output_2_generated_by_model
        adapter.sep_ids
        adapter.encode_and_decorate(user_input_3)
        ```

        Thus the adapter is responsible for providing the model-specific
        ``start_ids``, ``sep_ids``, and a method to encode a single prompt.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Model specific method to encode and decorate prompt."""
        raise NotImplementedError

    def decode(self, value):
        """Model specific method to decode single value to string."""
        raise NotImplementedError

    @property
    def stopping_criteria(self):
        """Model specific stopping criteria for generation."""
        return None

    @property
    def start_ids(self):
        """Model specific start ids."""
        return [self.tokenizer.bos_token_id]

    @property
    def sep_ids(self):
        """Model specific separation ids."""
        return [self.tokenizer.bos_token_id]


class BasicAdapter(BaseAdapter):
    """Basic adapter for slow tokenizers."""

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Fallback decoding when the tokenizer is not fast."""

        self.tokenizer: PreTrainedTokenizer

        tok = self.tokenizer.decode(value)
        return tok + ' '


class BasicAdapterFast(BaseAdapter):
    """Basic adapter for fast tokenizers."""

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')

    def encode_and_decorate(self, prompt, add_special_tokens=False):
        """Encode prompt.

        Note:
            We leave <bos> for the session manager to add.
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=add_special_tokens,
            return_tensors='pt',
        )
        logger.debug(f'Encode {prompt} to {input_ids}')
        return input_ids

    def decode(self, value):
        """Decode with fast tokenizers."""

        self.tokenizer: PreTrainedTokenizerFast

        tok = self.tokenizer._convert_id_to_token(value)
        if tok.startswith('▁'):  # sentencepiece
            space = ' '
            tok = tok[1:]
        else:
            space = ''
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '\r':
            tok = '\n'

        tok = space + tok

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok
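
To make the sequence described in the `BaseAdapter` docstring concrete, a caller could assemble the ids for a second chat turn roughly as follows; the history tensor holds placeholder values and the exact concatenation lives in the session manager, not in this sketch:

```python
import torch
from transformers import AutoTokenizer

from lmdeploy.pytorch.adapters.base import BasicAdapterFast

tokenizer = AutoTokenizer.from_pretrained('path/to/hf_model')  # placeholder
adapter = BasicAdapterFast(tokenizer)

prompt_ids = adapter.encode_and_decorate('How are you?')  # shape (1, n)
history_ids = torch.tensor([[101, 102, 103]])  # turn-1 ids (placeholders)

# start_ids + history + sep_ids + new prompt, as in the docstring above
complete_ids = torch.cat([
    torch.tensor([adapter.start_ids]),
    history_ids,
    torch.tensor([adapter.sep_ids]),
    prompt_ids,
], dim=1)
```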
81 changes: 81 additions & 0 deletions lmdeploy/pytorch/adapters/internlm.py
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import re

import torch
from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
StoppingCriteriaList)

from .base import BaseAdapter

logger = logging.getLogger(__name__)


class InternLMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for HF version of InternLM."""

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        return input_ids[0, -1] in [2, 103028]


class InternLMAdapter(BaseAdapter):
    r"""Adapter for InternLM.

    InternLM uses the following template, in which '\n' should be token 13:

        <bos> (no actual newline here, just for better readability)
        <|User|>:{prompt}<eoh>\n
        <|Bot|>:{model_output}<eoa>\n
        <|User|>:{prompt}<eoh>\n
        <|Bot|>:{model_output}<eoa>\n
        ...
        <eos>
    """

    hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
    # ids of '<|User|>:'
    B_USER_ID = torch.tensor([[333, 352, 1621, 352, 27232]])
    # ids of '<eoh>\n<|Bot|>:'
    E_USER_ID = torch.tensor([[103027, 13, 333, 352, 23845, 352, 27232]])
    # ids of '<bos>'
    start_ids = [1]
    # ids of '\n'
    sep_ids = [13]

    def __init__(self, tokenizer: PreTrainedTokenizerFast):
        self.tokenizer = tokenizer

    def encode_and_decorate(self, prompt):
        r"""Encode prompt and decorate with template.

        Note:
            We leave <bos> and chat history for the session manager to add,
            so we decorate input_ids to '<|User|>:{prompt}<eoh>\n<|Bot|>:'
        """
        input_ids = self.tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors='pt',
        )
        # This is f'<|User|>:{prompt}<eoh>\n<|Bot|>:'
        # but force \n to 13 instead of 364
        input_ids = torch.cat([self.B_USER_ID, input_ids, self.E_USER_ID],
                              dim=1)
        return input_ids

    def decode(self, value):
        """Decode generated tokens for InternLM."""

        tok = self.tokenizer.decode(value)
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok == '</s>' or tok == '<eoa>' or tok == '\r':
            tok = '\n'

        logger.debug(f'Decode {value} to {repr(tok)}')

        return tok

    @property
    def stopping_criteria(self):
        return StoppingCriteriaList([InternLMStoppingCriteria()])
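
For illustration, the adapter's stopping criteria plug directly into HuggingFace `generate`; the checkpoint name below is only an example, and InternLM checkpoints typically require `trust_remote_code=True`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmdeploy.pytorch.adapters.internlm import InternLMAdapter

model_path = 'internlm/internlm-chat-7b'  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             trust_remote_code=True)

adapter = InternLMAdapter(tokenizer)
input_ids = adapter.encode_and_decorate('Hello!')

# Generation stops once the last generated id is 2 or 103028,
# as checked by InternLMStoppingCriteria above.
output = model.generate(input_ids,
                        max_new_tokens=128,
                        stopping_criteria=adapter.stopping_criteria)
print(tokenizer.decode(output[0]))
```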
