Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions ai_chatbot_backend/app/services/generation/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,14 @@ async def run(self) -> AsyncIterator[str]:
if hasattr(output, 'choices') and output.choices:
delta = output.choices[0].delta

if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
self.ctx.accumulated_reasoning += delta.reasoning_content
reasoning_content = (
getattr(delta, 'reasoning_content', None)
or (delta.model_extra or {}).get('reasoning_content')
or (delta.model_extra or {}).get('reasoning')
or ''
)
if reasoning_content:
self.ctx.accumulated_reasoning += reasoning_content

if hasattr(delta, 'content') and delta.content:
self.ctx.accumulated_content += delta.content
Expand Down Expand Up @@ -227,8 +233,13 @@ async def run(self) -> AsyncIterator[str]:
async for audio_event in self._interleave_audio(channels['final']):
yield audio_event
else:
# for/else: flush remaining chunks after stream ends
channels = self.ctx.previous_channels
# for/else: flush any content held back by PARTIAL_TAIL_GUARD
# Recompute channels from the final accumulated text (bypasses the guard)
if self.ctx.accumulated_reasoning:
final_text = f"<think>{self.ctx.accumulated_reasoning}</think>{self.ctx.accumulated_content}"
else:
final_text = self.ctx.accumulated_content
channels = extract_channels(final_text) or self.ctx.previous_channels
chunks = {
c: channels[c][len(self.ctx.previous_channels.get(c, "")):]
for c in channels
Expand All @@ -241,6 +252,7 @@ async def run(self) -> AsyncIterator[str]:
yield sse(ResponseDelta(seq=self.ctx.text_seq, text_channel=channel, text=chunk))
self.ctx.text_seq += 1
print(chunk, end="")
self.ctx.previous_channels = channels

# Flush remaining audio
if self.audio_response and 'final' in channels:
Expand Down
8 changes: 6 additions & 2 deletions ai_chatbot_backend/app/services/generation/model_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
SAMPLING_PARAMS = {
"temperature": 0.6,
"top_p": 0.95,
"max_tokens": 6000,
"extra_body": {"top_k": 20, "min_p": 0}
"max_tokens": 2000,
"extra_body": {
"top_k": 20,
"min_p": 0,
"chat_template_kwargs": {"thinking_budget": 512},
}
}


Expand Down
22 changes: 19 additions & 3 deletions ai_chatbot_backend/app/services/generation/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ def extract_channels(text: str) -> dict:
if not text:
return {"analysis": "", "final": ""}

# Handle case where vLLM strips the opening <think> token but passes </think> through.
# In this case the text starts with thinking content and contains </think> with no leading <think>.
if "</think>" in text and not re.match(r"^\s*<think>", text):
parts = text.split("</think>", 1)
incomplete_patterns = ["</think", "</thin", "</thi", "</th", "</t", "</", "<"]
cleaned = parts[0]
for pattern in incomplete_patterns:
if parts[0].endswith(pattern):
cleaned = parts[0][:-len(pattern)]
break
return {"analysis": cleaned.strip(), "final": parts[1].strip()}

# Only treat `<think>...</think>` as a wrapper when it is a leading block.
if re.match(r"^\s*<think>", text):
if "</think>" in text:
Expand All @@ -180,16 +192,20 @@ def extract_channels(text: str) -> dict:
return {"analysis": m.group("analysis").strip(), "final": m.group("final").strip()}

parts = text.split("</think>", 1)
return {"analysis": parts[0].strip(), "final": parts[1].strip()}
analysis = re.sub(r"^\s*<think>\s*", "", parts[0]).strip()
return {"analysis": analysis, "final": parts[1].strip()}

# Streaming: `<think>` started but hasn't closed yet.
incomplete_patterns = ["</think", "</", "<"]
# All partial suffixes of `</think>` must be listed so they are stripped before emitting.
incomplete_patterns = ["</think", "</thin", "</thi", "</th", "</t", "</", "<"]
cleaned_text = text
for pattern in incomplete_patterns:
if text.endswith(pattern):
cleaned_text = text[:-len(pattern)]
break
return {"analysis": cleaned_text.strip(), "final": ""}
# Strip the leading <think> tag so analysis is tag-free (consistent with complete case)
analysis = re.sub(r"^\s*<think>\s*", "", cleaned_text).strip()
return {"analysis": analysis, "final": ""}

# No think wrapper → everything is final (supports pure-JSON outputs).
thinking = _extract_top_level_json_string_field(text, "thinking")
Expand Down
44 changes: 33 additions & 11 deletions ai_chatbot_backend/app/services/query/reformulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from typing import Any

from openai import AsyncOpenAI

from app.config import settings
from app.dependencies.model import get_vllm_chat_client

# Singleton client for reformulation — avoids leaking a new connection pool on every call
_reformulation_client: AsyncOpenAI | None = None


def _get_reformulation_client() -> AsyncOpenAI:
    """Return the process-wide AsyncOpenAI client for query reformulation.

    Lazily constructs the client on first call and caches it in the
    module-level ``_reformulation_client`` singleton, so every
    reformulation request shares one connection pool instead of leaking
    a fresh pool per call.
    """
    global _reformulation_client
    if _reformulation_client is not None:
        return _reformulation_client
    _reformulation_client = AsyncOpenAI(
        base_url=settings.vllm_chat_url,
        api_key=settings.vllm_api_key,
    )
    return _reformulation_client


# Query reformulator prompt
_QUERY_REFORMULATOR_PROMPT = (
Expand All @@ -14,8 +29,6 @@
"to align terminology and target specific topics. Include relevant constraints "
"(dates, versions, scope), and avoid adding facts not in the history. "
"Return only the rewritten query as question in plain text—no quotes, no extra text."
"# Valid channels: analysis, commentary, final. Channel must be included for every message."
"Calls to these tools must go to the commentary channel: 'functions'.<|end|>"
)


Expand Down Expand Up @@ -70,17 +83,26 @@ async def build_retrieval_query(

print(f"[DEBUG] Reformulation input ({len(request_content)} chars):\n{request_content[:2000]}...")

client = get_vllm_chat_client()
client = _get_reformulation_client()
response = await client.chat.completions.create(
model=settings.vllm_chat_model,
messages=chat,
temperature=0.6,
top_p=0.95,
max_tokens=500,
extra_body={"top_k": 20, "min_p": 0}
max_tokens=512,
timeout=30.0,
extra_body={"top_k": 20, "min_p": 0, "chat_template_kwargs": {"enable_thinking": False}},
)
msg = response.choices[0].message
content = msg.content or ""
# reasoning_content is a vLLM extension stored in model_extra by the openai SDK
reasoning = (
getattr(msg, "reasoning_content", None)
or (msg.model_extra or {}).get("reasoning_content")
or (msg.model_extra or {}).get("reasoning")
or ""
)
# vLLM with --reasoning-parser separates reasoning_content from content
# Use content directly (final response without thinking)
text = response.choices[0].message.content or ""
print(f"[INFO] Generated RAG-Query: {text.strip()}")
return text.strip()
print(f"[DEBUG] Reformulation raw: content={repr(content[:200])}, reasoning_content={repr(reasoning[:200])}")
text = content.strip() or reasoning.strip()
print(f"[INFO] Generated RAG-Query: {text[:200]}")
return text or user_message
8 changes: 4 additions & 4 deletions ai_chatbot_backend/scripts/start_vllm_servers.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Server configurations
CHAT_MODEL="cpatonn/Qwen3-30B-A3B-Thinking-2507-AWQ-4bit"
CHAT_MODEL="/home/tai25/models/qwen3.5-27b-awq-4bit"
CHAT_PORT=8001
CHAT_GPUS="0,1"

Expand Down Expand Up @@ -187,7 +187,7 @@ main() {
# Start Chat Model Server (Port 8001) - requires 2 GPUs for tensor parallel
start_server "chat" "$CHAT_MODEL" "$CHAT_PORT" "$CHAT_GPUS" \
"--tensor-parallel-size 2" \
"--gpu-memory-utilization 0.47" \
"--gpu-memory-utilization 0.55" \
"--max-model-len 10000" \
"--max_num_seqs 32" \
"--reasoning-parser deepseek_r1"
Expand All @@ -202,7 +202,7 @@ main() {
start_server "embed" "$EMBEDDING_MODEL" "$EMBEDDING_PORT" "$EMBEDDING_GPUS" \
"--max-model-len 10000" \
"--max-num-seqs 32" \
"--gpu-memory-utilization 0.4"
"--gpu-memory-utilization 0.3"

if ! wait_for_server $EMBEDDING_PORT "Embedding"; then
log_error "Embedding server failed to start. Check tmux session for errors."
Expand All @@ -212,7 +212,7 @@ main() {

# Start Whisper Server (Port 8003)
start_server "whisper" "$WHISPER_MODEL" "$WHISPER_PORT" "$WHISPER_GPUS" \
"--gpu-memory-utilization 0.37"
"--gpu-memory-utilization 0.2"

if ! wait_for_server $WHISPER_PORT "Whisper"; then
log_error "Whisper server failed to start. Check tmux session for errors."
Expand Down
Loading