Add token usage field in OpenAI api server (#176)
li-plus authored Nov 6, 2023
1 parent 615cbce commit 95d3b8c
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions chatglm_cpp/openai_api.py
@@ -4,9 +4,9 @@
 from typing import List, Literal, Optional, Union
 
 import chatglm_cpp
-from fastapi import FastAPI, HTTPException, Request, status
+from fastapi import FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, computed_field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
 
@@ -53,12 +53,23 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     finish_reason: Optional[Literal["stop", "length"]] = None
 
 
+class ChatCompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+
+    @computed_field
+    @property
+    def total_tokens(self) -> int:
+        return self.prompt_tokens + self.completion_tokens
+
+
 class ChatCompletionResponse(BaseModel):
     id: str = "chatcmpl"
     model: str = "default-model"
     object: Literal["chat.completion", "chat.completion.chunk"]
     created: int = Field(default_factory=lambda: int(time.time()))
     choices: Union[List[ChatCompletionResponseChoice], List[ChatCompletionResponseStreamChoice]]
+    usage: Optional[ChatCompletionUsage] = None
 
     model_config = {
         "json_schema_extra": {
@@ -75,6 +86,7 @@ class ChatCompletionResponse(BaseModel):
                             "finish_reason": "stop",
                         }
                     ],
+                    "usage": {"prompt_tokens": 17, "completion_tokens": 29, "total_tokens": 46},
                 }
             ]
         }
@@ -141,18 +153,23 @@ async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
         generator = stream_chat_event_publisher(history, body)
         return EventSourceResponse(generator)
 
+    max_context_length = 512
     output = pipeline.chat(
         history=history,
         max_length=body.max_tokens,
+        max_context_length=max_context_length,
         do_sample=body.temperature > 0,
         top_p=body.top_p,
         temperature=body.temperature,
     )
     logging.info(f'prompt: "{history[-1]}", sync response: "{output}"')
+    prompt_tokens = len(pipeline.tokenizer.encode_history(history, max_context_length))
+    completion_tokens = len(pipeline.tokenizer.encode(output, body.max_tokens))
 
     return ChatCompletionResponse(
         object="chat.completion",
         choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output))],
+        usage=ChatCompletionUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
     )
 
 
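Note on the new model: total_tokens is not stored on ChatCompletionUsage; it is a pydantic computed_field, so it is derived from the two stored counts and included automatically whenever the response is serialized. A minimal standalone sketch of that behavior (the class is copied from the diff above; the driver lines and the 17/29 values are only illustrative):

from pydantic import BaseModel, computed_field


class ChatCompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int

    @computed_field
    @property
    def total_tokens(self) -> int:
        # Derived on access; pydantic v2 includes computed fields in model_dump()/JSON output.
        return self.prompt_tokens + self.completion_tokens


usage = ChatCompletionUsage(prompt_tokens=17, completion_tokens=29)
print(usage.model_dump())
# {'prompt_tokens': 17, 'completion_tokens': 29, 'total_tokens': 46}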

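With this change, a non-streaming chat completion response carries the usage block shown in the example schema above. A rough client-side sketch, assuming the server is running locally and exposes an OpenAI-compatible /v1/chat/completions route on port 8000 (host, port, path, and request body here are assumptions, not taken from this diff):

import requests

# Hypothetical local deployment; adjust the URL to match how openai_api.py is served.
resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello"}]},
)
data = resp.json()
# Expected shape per the diff: {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}
print(data.get("usage"))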