@@ -15,7 +15,6 @@
 import json
 import logging
 import os
-import random
 import uuid
 import warnings
 from copy import deepcopy
@@ -26,7 +25,7 @@
 from huggingface_hub.utils import is_torch_available

 from .tools import Tool
-from .utils import _is_package_available, encode_image_base64, make_image_url
+from .utils import _is_package_available, encode_image_base64, make_image_url, parse_json_blob


 if TYPE_CHECKING:
@@ -236,10 +235,34 @@ def get_clean_message_list(
     return output_message_list


+def get_tool_call_chat_message_from_text(text: str, tool_name_key: str, tool_arguments_key: str) -> ChatMessage:
+    tool_call_dictionary, text = parse_json_blob(text)
+    try:
+        tool_name = tool_call_dictionary[tool_name_key]
+    except Exception as e:
+        raise ValueError(
+            f"Key {tool_name_key=} not found in the generated tool call. Got keys: {list(tool_call_dictionary.keys())} instead"
+        ) from e
+    tool_arguments = tool_call_dictionary.get(tool_arguments_key, None)
+    return ChatMessage(
+        role="assistant",
+        content=text,
+        tool_calls=[
+            ChatMessageToolCall(
+                id=str(uuid.uuid4()),
+                type="function",
+                function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+            )
+        ],
+    )
+
+
 class Model:
-    def __init__(self, **kwargs):
+    def __init__(self, tool_name_key: str = "name", tool_arguments_key: str = "arguments", **kwargs):
         self.last_input_token_count = None
         self.last_output_token_count = None
+        self.tool_name_key = tool_name_key
+        self.tool_arguments_key = tool_arguments_key
         self.kwargs = kwargs

     def _prepare_completion_kwargs(
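
For illustration, a usage sketch of the new helper (not part of the diff; it assumes parse_json_blob returns the parsed JSON object plus the surrounding text, as the unpacking above implies):

    # Hypothetical example: a raw completion that embeds a JSON tool call.
    raw = 'Thought: I will search.\nAction:\n{"name": "web_search", "arguments": {"query": "smolagents"}}'
    msg = get_tool_call_chat_message_from_text(raw, tool_name_key="name", tool_arguments_key="arguments")
    call = msg.tool_calls[0]
    assert call.function.name == "web_search"
    assert call.function.arguments == {"query": "smolagents"}
    # If the blob lacks the "name" key, a ValueError listing the found keys is raised.
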
@@ -465,6 +488,104 @@ def __call__(
         return message


+class VLLMModel(Model):
+    """Model to use [vLLM](https://docs.vllm.ai/) for fast LLM inference and serving.
+
+    Parameters:
+        model_id (`str`):
+            The Hugging Face model ID to be used for inference.
+            This can be a path or model identifier from the Hugging Face model hub.
+    """
+
+    def __init__(self, model_id, **kwargs):
+        if not _is_package_available("vllm"):
+            raise ModuleNotFoundError("Please install 'vllm' extra to use VLLMModel: `pip install 'smolagents[vllm]'`")
+
+        from vllm import LLM
+        from vllm.transformers_utils.tokenizer import get_tokenizer
+
+        super().__init__(**kwargs)
+
+        self.model_id = model_id
+        self.model = LLM(model=model_id)
+        self.tokenizer = get_tokenizer(model_id)
+        self._is_vlm = False  # VLLMModel does not support vision models yet.
+
+    def cleanup(self):
+        import gc
+
+        import torch
+        from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+
+        destroy_model_parallel()
+        if self.model is not None:
+            # taken from https://github.com/vllm-project/vllm/issues/1908#issuecomment-2076870351
+            del self.model.llm_engine.model_executor.driver_worker
+        self.model = None
+        gc.collect()
+        destroy_distributed_environment()
+        torch.cuda.empty_cache()
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
+    ) -> ChatMessage:
+        from vllm import SamplingParams
+
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            flatten_messages_as_text=(not self._is_vlm),
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+        messages = completion_kwargs.pop("messages")
+        prepared_stop_sequences = completion_kwargs.pop("stop", [])
+        tools = completion_kwargs.pop("tools", None)
+        completion_kwargs.pop("tool_choice", None)
+
+        if tools_to_call_from is not None:
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tools=tools,
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+        else:
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+            )
+
+        sampling_params = SamplingParams(
+            n=kwargs.get("n", 1),
+            temperature=kwargs.get("temperature", 0.0),
+            max_tokens=kwargs.get("max_tokens", 2048),
+            stop=prepared_stop_sequences,
+        )
+
+        out = self.model.generate(
+            prompt,
+            sampling_params=sampling_params,
+        )
+        output = out[0].outputs[0].text
+        self.last_input_token_count = len(out[0].prompt_token_ids)
+        self.last_output_token_count = len(out[0].outputs[0].token_ids)
+        if tools_to_call_from:
+            chat_message = get_tool_call_chat_message_from_text(output, self.tool_name_key, self.tool_arguments_key)
+            chat_message.raw = {"out": out, "completion_kwargs": completion_kwargs}
+            return chat_message
+        else:
+            return ChatMessage(
+                role="assistant", content=output, raw={"out": out, "completion_kwargs": completion_kwargs}
+            )
+
+
 class MLXModel(Model):
     """A class to interact with models loaded using MLX on Apple silicon.

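
For context, a minimal usage sketch of the VLLMModel class added above (assumes a CUDA GPU, the smolagents[vllm] extra, and that the class is importable from the package root per the __all__ change at the end of this diff; the model ID is only an example):

    from smolagents import VLLMModel

    model = VLLMModel(model_id="Qwen/Qwen2.5-7B-Instruct")
    message = model(
        messages=[{"role": "user", "content": "Say hello in one word."}],
        max_tokens=64,
    )
    print(message.content)
    model.cleanup()  # releases GPU memory and tears down vLLM's distributed state
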
@@ -523,27 +644,7 @@ def __init__(
         self.stream_generate = mlx_lm.stream_generate
         self.tool_name_key = tool_name_key
         self.tool_arguments_key = tool_arguments_key
-
-    def _to_message(self, text, tools_to_call_from):
-        if tools_to_call_from:
-            # solution for extracting tool JSON without assuming a specific model output format
-            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
-            parsed_text = json.loads(maybe_json)
-            tool_name = parsed_text.get(self.tool_name_key, None)
-            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
-            if tool_name:
-                return ChatMessage(
-                    role="assistant",
-                    content="",
-                    tool_calls=[
-                        ChatMessageToolCall(
-                            id=str(uuid.uuid4()),
-                            type="function",
-                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
-                        )
-                    ],
-                )
-        return ChatMessage(role="assistant", content=text)
+        self._is_vlm = False  # mlx-lm doesn't support vision models

     def __call__(
         self,
@@ -554,7 +655,7 @@ def __call__(
         **kwargs,
     ) -> ChatMessage:
         completion_kwargs = self._prepare_completion_kwargs(
-            flatten_messages_as_text=True,  # mlx-lm doesn't support vision models
+            flatten_messages_as_text=(not self._is_vlm),
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
@@ -583,9 +684,19 @@ def __call__(
                 stop_sequence_start = text.rfind(stop_sequence)
                 if stop_sequence_start != -1:
                     text = text[:stop_sequence_start]
-                    return self._to_message(text, tools_to_call_from)
+                    found_stop_sequence = True
+                    break
+            if found_stop_sequence:
+                break

-        return self._to_message(text, tools_to_call_from)
+        if tools_to_call_from:
+            chat_message = get_tool_call_chat_message_from_text(text, self.tool_name_key, self.tool_arguments_key)
+            chat_message.raw = {"out": text, "completion_kwargs": completion_kwargs}
+            return chat_message
+        else:
+            return ChatMessage(
+                role="assistant", content=text, raw={"out": text, "completion_kwargs": completion_kwargs}
+            )


 class TransformersModel(Model):
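
Note that found_stop_sequence is only set inside the loop; the diff context implies it is initialized to False earlier in __call__, outside the lines shown. A standalone sketch of the early-stop pattern:

    # Toy stand-in for the streaming loop above; the token list mimics
    # chunks coming from self.stream_generate(...).
    text = ""
    found_stop_sequence = False
    for token in ["Hel", "lo<", "end", ">tail"]:
        text += token
        for stop_sequence in ["<end>"]:
            stop_sequence_start = text.rfind(stop_sequence)
            if stop_sequence_start != -1:
                text = text[:stop_sequence_start]  # drop the stop marker and anything after it
                found_stop_sequence = True
                break
        if found_stop_sequence:
            break
    print(text)  # prints "Hello"
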
@@ -779,38 +890,14 @@ def __call__(
         if stop_sequences is not None:
             output = remove_stop_sequences(output, stop_sequences)

-        if tools_to_call_from is None:
-            return ChatMessage(
-                role="assistant",
-                content=output,
-                raw={"out": out, "completion_kwargs": completion_kwargs},
-            )
+        if tools_to_call_from:
+            chat_message = get_tool_call_chat_message_from_text(output, self.tool_name_key, self.tool_arguments_key)
+            chat_message.raw = {"out": out, "completion_kwargs": completion_kwargs}
+            return chat_message
         else:
-            if "Action:" in output:
-                output = output.split("Action:", 1)[1].strip()
-            try:
-                start_index = output.index("{")
-                end_index = output.rindex("}")
-                output = output[start_index : end_index + 1]
-            except Exception as e:
-                raise Exception("No json blob found in output!") from e
-
-            try:
-                parsed_output = json.loads(output)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
-            tool_name = parsed_output.get("name")
-            tool_arguments = parsed_output.get("arguments")
             return ChatMessage(
                 role="assistant",
-                content="",
-                tool_calls=[
-                    ChatMessageToolCall(
-                        id="".join(random.choices("0123456789", k=5)),
-                        type="function",
-                        function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
-                    )
-                ],
+                content=output,
                 raw={"out": out, "completion_kwargs": completion_kwargs},
             )

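
This hunk delegates the old inline extraction (split on "Action:", slice the outermost braces, json.loads) to the shared parse_json_blob helper. A toy reimplementation of the assumed contract, matching how get_tool_call_chat_message_from_text unpacks its two return values; the real smolagents implementation may differ:

    import json

    def parse_json_blob_sketch(text):
        # Hypothetical stand-in: parse the outermost {...} blob and return it
        # together with the text preceding it.
        start, end = text.index("{"), text.rindex("}")
        return json.loads(text[start : end + 1]), text[:start]

    blob, prefix = parse_json_blob_sketch('Action:\n{"name": "final_answer", "arguments": "done"}')
    assert blob["name"] == "final_answer"
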
@@ -1051,6 +1138,7 @@ def create_client(self):
     "HfApiModel",
     "LiteLLMModel",
     "OpenAIServerModel",
+    "VLLMModel",
     "AzureOpenAIServerModel",
     "ChatMessage",
 ]