Commit f066384

lvhan028 and lzhangzz authored
Fix get_ppl & get_logits (#3008)
* refactor
* async interface
* update perf metrics & adaptive tokens per tick
* wait-free
* refactor gateway
* optimize throughput
* add cancel cb
* simplify async engine
* simplify async engine
* fix end session
* faster synchronization
* fix async engine
* refactor async engine
* fix semaphore
* refactor inference API
* remove turbomind sync interface
* fix msvc build
* fix msvc build
* fix msvc build
* add extra outputs
* skip stop tokens
* exit gracefully
* cancel all tasks atexit
* refactor profiler
* fix id2step for api server
* save csv
* fix interactive
* fix lint
* fix generate_token_len
* fix async_end
* update pipeline ut
* fix ignore eos
* minor
* refactor profile pipeline api
* fix get_logits
* update get_ppl
* fix benchmark script
* fix get_ppl
* bring get_logits API back
* update user guide
* resolve reviewer's comments
* update
* fix
* update

---------

Co-authored-by: Li Zhang <[email protected]>
1 parent 086481e commit f066384

9 files changed: +239 -430 lines changed

benchmark/profile_pipeline_api.py

+1 -1

@@ -76,7 +76,7 @@ def process_request(self, requests, profiler: Profiler, temperature, top_p,
             top_p=top_p,
             top_k=top_k,
             ignore_eos=True,
-            do_sample=True,
+            do_sample=False,
             max_new_tokens=output_len)
         for _, _, output_len in requests
     ]
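
The switch to `do_sample=False` makes the profiler decode greedily, so repeated benchmark runs produce identical token streams and comparable timings. A minimal sketch of the same configuration outside the profiler; the model name and token budget below are illustrative assumptions, not part of the commit:

```python
from lmdeploy import pipeline, GenerationConfig

# Illustrative model; any chat model served by lmdeploy works the same way.
pipe = pipeline('internlm/internlm2_5-7b-chat')

gen_config = GenerationConfig(
    top_p=0.8,
    top_k=40,
    ignore_eos=True,     # keep generating past EOS, as the profiler does
    do_sample=False,     # greedy decoding: deterministic output
    max_new_tokens=128)  # illustrative output length

response = pipe(['Shanghai is'], gen_config=gen_config)
print(response[0].text)
```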

docs/en/llm/pipeline.md

+43 -14

@@ -6,7 +6,7 @@ You can overview the detailed pipeline API in [this](https://lmdeploy.readthedoc
 
 ## Usage
 
-- **An example using default parameters:**
+### A 'Hello, world' example
 
 ```python
 from lmdeploy import pipeline
@@ -40,7 +40,7 @@ There have been alterations to the strategy for setting the k/v cache ratio thro
 
 The allocation strategy for k/v cache is changed to reserve space from the **GPU free memory** proportionally. The ratio `TurbomindEngineConfig.cache_max_entry_count` has been adjusted to 0.8 by default. If OOM error happens, similar to the method mentioned above, please consider reducing the ratio value to decrease the memory usage of the k/v cache.
 
-- **An example showing how to set tensor parallel num**:
+### Set tensor parallelism
 
 ```python
 from lmdeploy import pipeline, TurbomindEngineConfig
@@ -52,7 +52,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
 print(response)
 ```
 
-- **An example for setting sampling parameters:**
+### Set sampling parameters
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -69,7 +69,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
 print(response)
 ```
 
-- **An example for OpenAI format prompt input:**
+### Apply OpenAI format prompt
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -93,7 +93,7 @@ response = pipe(prompts,
 print(response)
 ```
 
-- **An example for streaming mode:**
+### Apply streaming output
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -116,31 +116,60 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
 print(item)
 ```
 
-- **An example to cauculate logits & ppl:**
+### Get logits for generated tokens
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config = GenerationConfig(output_logits='generation',
+                              max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+                gen_config=gen_config)
+logits = [x.logits for x in response]
+```
+
+### Get last layer's hidden states for generated tokens
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config = GenerationConfig(output_last_hidden_state='generation',
+                              max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+                gen_config=gen_config)
+hidden_states = [x.last_hidden_state for x in response]
+```
+
+### Calculate ppl
 
 ```python
 from transformers import AutoTokenizer
 from lmdeploy import pipeline
-model_repoid_or_path='internlm/internlm2_5-7b-chat'
+
+
+model_repoid_or_path = 'internlm/internlm2_5-7b-chat'
 pipe = pipeline(model_repoid_or_path)
 tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)
-
-# logits
 messages = [
    {"role": "user", "content": "Hello, how are you?"},
 ]
 input_ids = tokenizer.apply_chat_template(messages)
-logits = pipe.get_logits(input_ids)
 
-# ppl
+# ppl is a list of float numbers
 ppl = pipe.get_ppl(input_ids)
+print(ppl)
 ```
 
 ```{note}
-get_ppl returns the cross entropy loss without applying the exponential operation afterwards
+- When input_ids is too long, an OOM (Out Of Memory) error may occur. Please apply it with caution
+- get_ppl returns the cross entropy loss without applying the exponential operation afterwards
 ```
 
-- **Below is an example for pytorch backend. Please install triton first.**
+### Use PyTorchEngine
 
 ```shell
 pip install triton>=2.1.0
@@ -167,7 +196,7 @@ response = pipe(prompts, gen_config=gen_config)
 print(response)
 ```
 
-- **An example for lora.**
+### Inference with LoRA
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
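
As a usage note on the updated "Calculate ppl" section above: since `get_ppl` returns the cross entropy loss, exponentiating it gives the conventional perplexity. A short sketch under that assumption; the `math.exp` step is our addition, not part of the commit:

```python
import math

from transformers import AutoTokenizer
from lmdeploy import pipeline

model_repoid_or_path = 'internlm/internlm2_5-7b-chat'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

messages = [
    {"role": "user", "content": "Hello, how are you?"},
]
input_ids = tokenizer.apply_chat_template(messages)

# get_ppl returns a list with one cross-entropy loss per input sequence
losses = pipe.get_ppl(input_ids)
# exponentiate the loss to obtain perplexity
perplexities = [math.exp(loss) for loss in losses]
print(perplexities)
```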

docs/en/multi_modal/vl_pipeline.md

+17 -20

@@ -4,7 +4,7 @@ LMDeploy abstracts the complex inference process of multi-modal Vision-Language
 
 The supported models are listed [here](../supported_models/supported_models.md). We genuinely invite the community to contribute new VLM support to LMDeploy. Your involvement is truly appreciated.
 
-This article showcases the VLM pipeline using the [liuhaotian/llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) model as a case study.
+This article showcases the VLM pipeline using the [OpenGVLab/InternVL2_5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) model as a case study.
 You'll learn about the simplest ways to leverage the pipeline and how to gradually unlock more advanced features by adjusting engine parameters and generation arguments, such as tensor parallelism, context window sizing, random sampling, and chat template customization.
 Moreover, we will provide practical inference examples tailored to scenarios with multiple images, batch prompts etc.
 
@@ -16,7 +16,7 @@ Using the pipeline interface to infer other VLM models is similar, with the main
 from lmdeploy import pipeline
 from lmdeploy.vl import load_image
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
 
 image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
 response = pipe(('describe this image', image))
@@ -30,7 +30,7 @@ In the above example, the inference prompt is a tuple structure consisting of (p
 ```python
 from lmdeploy import pipeline
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
 
 prompts = [
     {
@@ -53,7 +53,7 @@ Tensor paramllelism can be activated by setting the engine parameter `tp`
 from lmdeploy import pipeline, TurbomindEngineConfig
 from lmdeploy.vl import load_image
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
                 backend_config=TurbomindEngineConfig(tp=2))
 
 image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -69,7 +69,7 @@ When creating the pipeline, you can customize the size of the context window by
 from lmdeploy import pipeline, TurbomindEngineConfig
 from lmdeploy.vl import load_image
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
                 backend_config=TurbomindEngineConfig(session_len=8192))
 
 image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -85,7 +85,7 @@ You can change the default sampling parameters of pipeline by passing `Generatio
 from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
 from lmdeploy.vl import load_image
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
                 backend_config=TurbomindEngineConfig(tp=2, session_len=8192))
 gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)
 image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -139,22 +139,19 @@ response = pipe(('describe this image', image))
 print(response)
 ```
 
-### Calculate logits
-
-We provide support for custom inputs. Users can utilize 'prepare_inputs' to understand how the inputs are organized.
+### Output logits for generated tokens
 
 ```python
-from lmdeploy import pipeline, TurbomindEngineConfig
+from lmdeploy import pipeline, GenerationConfig
 from lmdeploy.vl import load_image
-pipe = pipeline('internlm/internlm-xcomposer2-7b', backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5))
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
 
-# logits
 image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
-inputs = pipe.prepare_inputs(('describe this image', image))
-input_ids = inputs['input_ids']
-embeddings = inputs['input_embeddings']
-embedding_ranges = inputs['input_embedding_ranges']
-logits = pipe.get_logits(input_ids, embeddings, embedding_ranges)
+
+response = pipe(('describe this image', image),
+                gen_config=GenerationConfig(output_logits='generation'))
+logits = response.logits
+print(logits)
 ```
 
 ## Multi-images inference
@@ -165,7 +162,7 @@ When dealing with multiple images, you can put them all in one list. Keep in min
 from lmdeploy import pipeline, TurbomindEngineConfig
 from lmdeploy.vl import load_image
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
                 backend_config=TurbomindEngineConfig(session_len=8192))
 
 image_urls=[
@@ -186,7 +183,7 @@ Conducting inference with batch prompts is quite straightforward; just place the
 from lmdeploy import pipeline, TurbomindEngineConfig
 from lmdeploy.vl import load_image
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
                 backend_config=TurbomindEngineConfig(session_len=8192))
 
 image_urls=[
@@ -206,7 +203,7 @@ There are two ways to do the multi-turn conversations with the pipeline. One is
 from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
 from lmdeploy.vl import load_image
 
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
                 backend_config=TurbomindEngineConfig(session_len=8192))
 
 image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')
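
For the new "Output logits for generated tokens" section, a hedged sketch of inspecting the result. It assumes the VLM pipeline's response exposes `logits` as a tensor with one row per generated token and carries a `generate_token_len` field like the text-only pipeline; neither assumption is stated by the commit itself:

```python
from lmdeploy import pipeline, GenerationConfig
from lmdeploy.vl import load_image

pipe = pipeline('OpenGVLab/InternVL2_5-8B')

image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image),
                gen_config=GenerationConfig(output_logits='generation',
                                            max_new_tokens=10))

# Expect shape (num_generated_tokens, vocab_size)
print(response.logits.shape)
# The row count should match the number of generated tokens
print(response.generate_token_len)
```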

docs/zh_cn/llm/pipeline.md

+46 -13

@@ -6,7 +6,7 @@ For the detailed pipeline API description, please read [this](https://lmdeploy.readthed
 
 ## Usage
 
-- **An example using default parameters:**
+### A "Hello, world" example
 
 ```python
 from lmdeploy import pipeline
@@ -40,7 +40,7 @@ During the development of LMDeploy, the strategy for setting the k/v cache ratio has changed. The following
 
 The allocation strategy now reserves space for the k/v cache proportionally from the **free GPU memory**. The default ratio is adjusted to 0.8. If OOM occurs, similar to the method above, please reduce the ratio as appropriate to lower the k/v cache memory usage
 
-- **How to set tp:**
+### Set multi-GPU parallelism
 
 ```python
 from lmdeploy import pipeline, TurbomindEngineConfig
@@ -52,7 +52,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
 print(response)
 ```
 
-- **How to set sampling parameters:**
+### Set random sampling parameters
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -69,7 +69,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
 print(response)
 ```
 
-- **How to use OpenAI format input:**
+### Use OpenAI format prompts
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -93,7 +93,7 @@ response = pipe(prompts,
 print(response)
 ```
 
-- **Streaming the results:**
+### Streaming output
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -116,31 +116,64 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
 print(item)
 ```
 
-- **Calculate logits & ppl:**
+### Get logits for generated tokens
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config = GenerationConfig(output_logits='generation',
+                              max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+                gen_config=gen_config)
+logits = [x.logits for x in response]
+```
+
+### Get the last layer's hidden_states for generated tokens
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config = GenerationConfig(output_last_hidden_state='generation',
+                              max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+                gen_config=gen_config)
+hidden_states = [x.last_hidden_state for x in response]
+```
+
+### Calculate ppl
 
 ```python
 from transformers import AutoTokenizer
 from lmdeploy import pipeline
-model_repoid_or_path='internlm/internlm2_5-7b-chat'
+
+
+model_repoid_or_path = 'internlm/internlm2_5-7b-chat'
 pipe = pipeline(model_repoid_or_path)
 tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)
-
-# logits
 messages = [
    {"role": "user", "content": "Hello, how are you?"},
 ]
 input_ids = tokenizer.apply_chat_template(messages)
+
+# logits is a list of tensor
 logits = pipe.get_logits(input_ids)
+print(logits)
 
-# ppl
+# ppl is a list of float numbers
 ppl = pipe.get_ppl(input_ids)
+print(ppl)
 ```
 
 ```{note}
+When input_ids is too long, an OOM error may occur; please use it with caution
 get_ppl returns the cross entropy loss without the exp operation applied afterwards
 ```
 
-- **Use the pytorch backend**
+### Use PyTorchEngine
 
 Triton needs to be installed first
 
@@ -169,7 +202,7 @@ response = pipe(prompts, gen_config=gen_config)
 print(response)
 ```
 
-- **A lora example**
+### Inference with a LoRA model
 
 ```python
 from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
@@ -190,7 +223,7 @@ response = pipe(prompts, gen_config=gen_config, adapter_name='lora_name_1')
 print(response)
 ```
 
-## FAQs
+## 常见问题
 
 - **RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase**.
 
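
Both language versions of the note warn that a very long `input_ids` can trigger OOM in `get_ppl`. A hedged workaround sketch: split the token list into windows and evaluate them as a batch. It assumes `get_ppl` accepts a list of token-id lists (as the returned list of floats suggests); the window size and the token-weighted averaging are our choices, not an API guarantee:

```python
from transformers import AutoTokenizer
from lmdeploy import pipeline

model_repoid_or_path = 'internlm/internlm2_5-7b-chat'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

long_text = 'Shanghai is a city in China. ' * 1000   # stand-in for a long document
input_ids = tokenizer.encode(long_text)

window = 2048                                         # assumed safe chunk length
chunks = [input_ids[i:i + window] for i in range(0, len(input_ids), window)]

# One cross-entropy loss per chunk
losses = pipe.get_ppl(chunks)

# Token-weighted average loss over the whole document (still a cross entropy, not exponentiated)
total_tokens = sum(len(c) for c in chunks)
avg_loss = sum(l * len(c) for l, c in zip(losses, chunks)) / total_tokens
print(avg_loss)
```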
