Update utils.py #41

Status: Open
Wants to merge 2 commits into main
fastmlx/utils.py (42 additions, 0 deletions)
@@ -439,6 +439,36 @@ def lm_stream_generator(
completion_tokens = 0
empty_usage: Usage = None

import numpy as np
from dataclasses import dataclass
from typing import List, Optional

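# Presumably mirrors mlx_lm's GenerationResponse, with logprobs stored as a
# plain list so the response is JSON-serializable.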
@dataclass
class MyGenerationResponse:
"""
The modified output of :func:`mlx_lm.utils.stream_generate`.

Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
logprobs (List[float]): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
peak_memory (float): The peak memory used so far in GB.
finish_reason (Optional[str]): The reason the response is being sent: "length", "stop", or None.
"""

text: str
token: int
logprobs: List[float]
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: Optional[str] = None

for token in lm_stream_generate(
model, tokenizer, prompt, max_tokens=max_tokens, temp=temperature
):
@@ -448,6 +478,18 @@ def lm_stream_generator(
# Update token length info
if INCLUDE_USAGE:
completion_tokens += 1

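# Re-wrap the raw response: logprobs arrive as an array object that is not
# JSON-serializable, so convert them to a plain Python list first.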
token = MyGenerationResponse(
text=token.text,
token=token.token,
logprobs=np.array(token.logprobs).tolist(),
prompt_tokens=token.prompt_tokens,
prompt_tps=token.prompt_tps,
generation_tokens=token.generation_tokens,
generation_tps=token.generation_tps,
peak_memory=token.peak_memory,
finish_reason=token.finish_reason
)

chunk = ChatCompletionChunk(
id=f"chatcmpl-{os.urandom(4).hex()}",
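
For reference, a minimal standalone sketch (not part of the PR; the sample values are hypothetical) of the property the re-wrap relies on: np.array(...).tolist() turns array-backed logprobs into plain Python floats, which json.dumps accepts.

import json
import numpy as np

# Hypothetical logprobs for three sampled tokens; in the PR these come from
# token.logprobs via np.array(token.logprobs).tolist().
logprobs = np.array([-0.12, -1.30, -4.70]).tolist()
print(json.dumps({"logprobs": logprobs}))  # {"logprobs": [-0.12, -1.3, -4.7]}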