diff --git a/fastmlx/utils.py b/fastmlx/utils.py
index 83e1567b..52612ecb 100644
--- a/fastmlx/utils.py
+++ b/fastmlx/utils.py
@@ -439,6 +439,36 @@ def lm_stream_generator(
     completion_tokens = 0
     empty_usage: Usage = None
 
+    import numpy as np
+    from dataclasses import dataclass
+    from typing import List, Optional
+    @dataclass
+    class MyGenerationResponse:
+        """
+        The modified output of :func:`mlx_lm.utils.stream_generate`.
+
+        Args:
+            text (str): The next segment of decoded text. This can be an empty string.
+            token (int): The next token.
+            logprobs (List[float]): A vector of log probabilities.
+            prompt_tokens (int): The number of tokens in the prompt.
+            prompt_tps (float): The prompt processing tokens-per-second.
+            generation_tokens (int): The number of generated tokens.
+            generation_tps (float): The tokens-per-second for generation.
+            peak_memory (float): The peak memory used so far in GB.
+            finish_reason (str): The reason the response is being sent: "length", "stop" or `None`.
+        """
+
+        text: str
+        token: int
+        logprobs: List[float]
+        prompt_tokens: int
+        prompt_tps: float
+        generation_tokens: int
+        generation_tps: float
+        peak_memory: float
+        finish_reason: Optional[str] = None
+
     for token in lm_stream_generate(
         model, tokenizer, prompt, max_tokens=max_tokens, temp=temperature
     ):
@@ -448,6 +478,18 @@ def lm_stream_generator(
         # Update token length info
         if INCLUDE_USAGE:
             completion_tokens += 1
+
+        token = MyGenerationResponse(
+            text=token.text,
+            token=token.token,
+            logprobs=np.array(token.logprobs).tolist(),
+            prompt_tokens=token.prompt_tokens,
+            prompt_tps=token.prompt_tps,
+            generation_tokens=token.generation_tokens,
+            generation_tps=token.generation_tps,
+            peak_memory=token.peak_memory,
+            finish_reason=token.finish_reason
+        )
         chunk = ChatCompletionChunk(
             id=f"chatcmpl-{os.urandom(4).hex()}",
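
Note (not part of the patch): below is a minimal, self-contained sketch of why `logprobs` is routed through `np.array(...).tolist()` before the token is handed to the chunk builder — array types are not JSON-serializable, while plain Python lists are. `ResponseSketch` and the sample values are illustrative assumptions, and numpy stands in here for the array type returned by the upstream generator; none of this is fastmlx code.

```python
# Illustrative only: ResponseSketch is a stand-in for MyGenerationResponse above,
# and numpy stands in for the array type produced by the upstream generator.
import json
from dataclasses import asdict, dataclass
from typing import List, Optional

import numpy as np


@dataclass
class ResponseSketch:
    text: str
    logprobs: List[float]
    finish_reason: Optional[str] = None


# An array-valued logprobs field, as yielded by the generator.
raw_logprobs = np.array([-0.12, -1.05, -3.4])

# Converting to a plain Python list keeps the dataclass JSON-friendly.
resp = ResponseSketch(text="hello", logprobs=np.array(raw_logprobs).tolist())
print(json.dumps(asdict(resp)))
# -> {"text": "hello", "logprobs": [-0.12, -1.05, -3.4], "finish_reason": null}

# Passing the raw array through unchanged fails at serialization time:
# json.dumps(asdict(ResponseSketch("hello", raw_logprobs)))
# -> TypeError: Object of type ndarray is not JSON serializable
```

Presumably this is what lets the wrapped token survive whatever JSON encoding the streamed `ChatCompletionChunk` goes through downstream.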