
Commit 273f7c7

[Feature] vllm wrapper
ghstack-source-id: 5e8c1974268bdf7f70543a4c44746dda92ea9ba2
Pull Request resolved: #2830
1 parent 919af4a commit 273f7c7

3 files changed, +223 -2 lines changed

torchrl/envs/utils.py

+4-1
@@ -942,7 +942,10 @@ def make_shape(shape):
            )
            if is_tensor_collection(tensor) and not is_non_tensor(tensor)
            else NonTensor(
-               shape=tensor.shape, example_data=tensor.data, device=tensor.device
+               shape=tensor.shape,
+               example_data=tensor.data,
+               device=tensor.device,
            )
            if is_non_tensor(tensor)
            else Unbounded(
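
For context, a minimal self-contained sketch (hypothetical variable names; it assumes tensordict's NonTensorData and torchrl's NonTensor spec) of what the reformatted branch constructs:

import torch
from tensordict import NonTensorData
from torchrl.data import NonTensor

# A non-tensor leaf such as a string carried inside a TensorDict.
tensor = NonTensorData("a piece of text", batch_size=[])

# The branch above builds a NonTensor spec mirroring the leaf's
# shape, example payload and device.
spec = NonTensor(
    shape=tensor.shape,
    example_data=tensor.data,
    device=tensor.device,
)
print(spec)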

torchrl/modules/llm/vllm.py

+215
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
import transformers
import vllm.outputs
from tensordict import (
    from_dataclass,
    maybe_dense_stack,
    NestedKey,
    NonTensorData,
    NonTensorStack,
    TensorClass,
    TensorDict,
)
from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictModuleBase,
    TensorDictSequential as Seq,
)
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput)


def _maybe_clear_device(td):
    if td.device is None:
        return td
    # Stash the source device in a non-tensor entry so it can be restored later.
    return td.set("_source_device", NonTensorData(td.device)).clear_device_()


def _maybe_set_device(td):
    device = td.pop("_source_device", None)
    if device is None:
        return td
    elif isinstance(device, NonTensorData):
        device: torch.device = device.data
    return td.to(device)


def from_vllm(
    model: LLM,
    return_log_probs: bool = False,
    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
    from_text: bool = False,
    device: torch.device | None = None,
    text_key: NestedKey = "text",
    generate_kwargs: dict | None = None,
    tokenizer_kwargs: dict | None = None,
) -> TensorDictModuleBase:
    """Wraps a vllm.LLM into a TensorDictSequential pipeline: encode -> generate -> decode."""
    # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
    module_dict = {}
    if device:
        module_dict["clear_device"] = _maybe_clear_device
    if from_text:
        if not tokenizer_kwargs:
            tokenizer_kwargs = {}
        # The pipeline relies on left-padded "pt" batches with attention masks.
        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
            raise RuntimeError
        if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
            raise RuntimeError
        if tokenizer_kwargs.setdefault("padding", True) not in (True,):
            raise RuntimeError
        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
            raise RuntimeError
        module_dict["encode"] = Mod(
            tokenizer,
            in_keys=[text_key],
            out_keys=["tokens_in"],
            method_kwargs=tokenizer_kwargs,
            strict=True,
        )

    def to_list(tokens):
        # vLLM consumes plain lists of token ids, not tensors.
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        return NonTensorStack(*tokens)

    module_dict["to_list"] = Mod(
        to_list,
        in_keys=[("tokens_in", "input_ids")],
        out_keys=[("tokens_in", "input_ids_list")],
    )

    if generate_kwargs is None:
        generate_kwargs = {"detokenize": False, "prompt_logprobs": 1, "logprobs": 1}
    sampling_params = SamplingParams(**generate_kwargs)

    module_dict["generate"] = Mod(
        model,
        method="generate",
        method_kwargs={"sampling_params": sampling_params},
        in_keys={
            "prompt_token_ids": ("tokens_in", "input_ids_list"),
            # "attention_mask": ("tokens_in", "attention_mask"),
        },
        out_keys=["tokens_out"],
        out_to_in_map=True,
        strict=True,
    )

    def get_output_tokens_and_log_probs(td):
        td["tokens_out"] = RequestOutput_tc.from_request_output(td["tokens_out"])
        td["output_tokens"] = td["tokens_out"].outputs.token_ids
        td["log_probs"] = td["tokens_out"].outputs.logprobs
        return td

    module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs

    if from_text:
        module_dict["to_list_decode"] = Mod(
            to_list, in_keys=["output_tokens"], out_keys=["output_tokens_list"]
        )
        module_dict["decode"] = Mod(
            tokenizer.batch_decode,
            in_keys=["output_tokens_list"],
            out_keys=["action"],
        )

    if device:
        module_dict["to_source_device"] = _maybe_set_device

    return Seq(module_dict)


class RequestOutput_tc(TensorClass["nocast"]):
    # Field annotations are placeholders; TensorClass["nocast"] does not cast them.
    request_id: str
    prompt: str
    prompt_token_ids: str
    prompt_logprobs: str
    outputs: str
    finished: str
    metrics: str
    lora_request: str
    encoder_prompt: str
    encoder_prompt_token_ids: str
    num_cached_tokens: str

    def __post_init__(self):
        def postproc(output):
            def get_logprob(output):
                t = []
                for v, tid in zip(output.logprobs, output.token_ids):
                    t.append(v[tid].logprob if v is not None else 0.0)
                return torch.tensor(t)

            output.logprobs = get_logprob(output)
            output.token_ids = torch.tensor(output.token_ids)
            return output

        if isinstance(self.outputs, list):
            outputs = self.outputs
            outputs = [
                postproc(from_dataclass(output, dest_cls=CompletionOutput_tc))
                for output in outputs
            ]
            if len(outputs) == 1:
                self.outputs = outputs[0]
            else:
                self.outputs = torch.stack(outputs)
            self.prompt_logprobs = torch.tensor(
                [
                    v[tid].logprob if v is not None else 0.0
                    for v, tid in zip(self.prompt_logprobs, self.prompt_token_ids)
                ]
            )
            self.prompt_token_ids = torch.tensor(self.prompt_token_ids)
            self.num_cached_tokens = torch.tensor(self.num_cached_tokens)

    @classmethod
    def from_request_output(cls, requests):
        return maybe_dense_stack(
            [
                cls(
                    request_id=request.request_id,
                    prompt=request.prompt,
                    prompt_token_ids=request.prompt_token_ids,
                    prompt_logprobs=request.prompt_logprobs,
                    outputs=request.outputs,
                    finished=request.finished,
                    metrics=request.metrics,
                    lora_request=request.lora_request,
                    encoder_prompt=request.encoder_prompt,
                    encoder_prompt_token_ids=request.encoder_prompt_token_ids,
                    num_cached_tokens=request.num_cached_tokens,
                )
                for request in requests
            ]
        )


if __name__ == "__main__":
    max_seq_length = 50000
    model_name = "Qwen/Qwen2.5-7B-Instruct"
    model = LLM(model_name, skip_tokenizer_init=True, device="cuda:0")
    model.llm_engine.model_executor.driver_worker.worker.model_runner.model.sampler.include_gpu_probs_tensor = (
        True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, device="cuda:0")
    # tokenizer.padding_side = "left"
    m = from_vllm(model, tokenizer=tokenizer, from_text=True, device="cuda:0")
    print(m(TensorDict(text=NonTensorStack("a text is a text", "another text"))))
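
As a rough, untested usage sketch for the token-level path (from_text=False), with hypothetical prompt ids and key names taken from the code above; it assumes a CUDA-capable vLLM install:

import torch
from tensordict import TensorDict
from vllm import LLM

model = LLM("Qwen/Qwen2.5-7B-Instruct", skip_tokenizer_init=True)
m = from_vllm(model)  # no tokenizer: inputs must already be token ids
td = TensorDict(
    {("tokens_in", "input_ids"): torch.randint(0, 1000, (2, 5))},
    batch_size=[2],
)
out = m(td)
print(out["output_tokens"], out["log_probs"])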

torchrl/modules/tensordict_module/actors.py

+4-1
@@ -2371,7 +2371,10 @@ def forward(
        action_entry = parent_td.get(action_key_orig[-1], None)
        if action_entry is None:
            raise self._NO_INIT_ERR
-       if self.n_steps is not None and action_entry.shape[parent_td.ndim] != self.n_steps:
+       if (
+           self.n_steps is not None
+           and action_entry.shape[parent_td.ndim] != self.n_steps
+       ):
            raise RuntimeError(
                f"The action's time dimension (dim={parent_td.ndim}) doesn't match the n_steps argument ({self.n_steps}). "
                f"The action shape was {action_entry.shape}."