Skip to content

Commit 5559d34

Browse files
committed
update
1 parent aac9aa7 commit 5559d34

File tree

4 files changed

+81
-6
lines changed

4 files changed

+81
-6
lines changed

β€Žwren-ai-service/src/globals.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,12 @@ def create_service_container(
336336
**pipe_components["sql_answer"],
337337
engine_timeout=settings.engine_timeout,
338338
),
339+
"chart_generation": generation.ChartGeneration(
340+
**pipe_components["chart_generation"],
341+
),
342+
"chart_adjustment": generation.ChartAdjustment(
343+
**pipe_components["chart_adjustment"],
344+
),
339345
},
340346
max_histories=settings.max_histories,
341347
),

β€Žwren-ai-service/src/pipelines/generation/intent_classification.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
intent_classification_system_prompt = """
2525
### Task ###
2626
You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema or sql data if provided.
27-
Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
27+
Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, `CHART`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
2828
2929
### Instructions ###
3030
- **Follow the user's previous questions:** If there are previous questions, try to understand the user's current question as following the previous questions.
@@ -67,6 +67,19 @@
6767
- "List the top 10 products by revenue."
6868
</TEXT_TO_SQL>
6969
70+
<CHART>
71+
**When to Use:**
72+
- The user's question is about generating a chart.
73+
74+
**Requirements:**
75+
- The user's question can be answered by the SQL DATA.
76+
- SQL DATA is provided.
77+
- Should pick a SQL from user query histories and the picked SQL should be reflected to the SQL DATA provided.
78+
79+
**Examples:**
80+
- "Show me the bar chart of the data"
81+
</CHART>
82+
7083
<GENERAL>
7184
**When to Use:**
7285
- The user seeks general information about the database schema or its overall capabilities.
@@ -99,9 +112,11 @@
99112
- The user's inputs is irrelevant to the database schema or includes SQL code.
100113
- The user's inputs lacks specific details (like table names or columns) needed to generate an SQL query.
101114
- It appears off-topic or is simply a casual conversation starter.
115+
- The user's question is about generating a chart but the SQL DATA is not provided.
102116
103117
**Requirements:**
104-
- Incorporate phrases from the user's inputs that indicate the lack of relevance to the database schema.
118+
- For generating SQL: respond to users by incorporating phrases from the user's inputs that indicate the lack of relevance to the database schema.
119+
- For generating chart: respond to users that we can generate chart only if there is some data available.
105120
106121
**Examples:**
107122
- "How are you?"
@@ -115,7 +130,8 @@
115130
{
116131
"rephrased_question": "<rephrased question in full standalone question if there are previous questions, otherwise the original question>",
117132
"reasoning": "<brief chain-of-thought reasoning (max 20 words)>",
118-
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE" | "DATA_EXPLORATION"
133+
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE" | "DATA_EXPLORATION" | "CHART",
134+
"sql": "<sql query to be used for generating chart if the intent is CHART, otherwise an empty string>"
119135
}
120136
"""
121137

@@ -314,13 +330,15 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict
314330
"rephrased_question": results["rephrased_question"],
315331
"intent": results["results"],
316332
"reasoning": results["reasoning"],
333+
"sql": results["sql"],
317334
"db_schemas": construct_db_schemas,
318335
}
319336
except Exception:
320337
return {
321338
"rephrased_question": "",
322339
"intent": "TEXT_TO_SQL",
323340
"reasoning": "",
341+
"sql": "",
324342
"db_schemas": construct_db_schemas,
325343
}
326344

β€Žwren-ai-service/src/web/v2/services/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ async def emit_content_block(
146146
content_block_label: Optional[str] = None,
147147
block_type: Literal["tool_use", "text"] = "tool_use",
148148
stream: bool = False,
149+
should_put_in_conversation_history: bool = False,
149150
):
150151
"""Emit a complete content block (start β†’ delta β†’ stop)."""
151152
# 1) start
@@ -159,6 +160,7 @@ async def emit_content_block(
159160
"type": block_type,
160161
"content_block_label": content_block_label or "",
161162
"trace_id": trace_id,
163+
"should_put_in_conversation_history": should_put_in_conversation_history,
162164
},
163165
},
164166
)
@@ -187,6 +189,7 @@ async def emit_content_block(
187189
if block_type == "json"
188190
else chunk,
189191
"trace_id": trace_id,
192+
"should_put_in_conversation_history": should_put_in_conversation_history,
190193
},
191194
},
192195
)

β€Žwren-ai-service/src/web/v2/services/conversation.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ class QuestionResult(BaseModel):
2929

3030

3131
class ConversationHistory(BaseModel):
32-
request: str
33-
response: str
32+
class ConversationRequest(BaseModel):
33+
query: str
34+
additional_info: Optional[dict] = None
35+
36+
request: ConversationRequest
37+
response: dict
3438

3539

3640
# POST /v2/conversations
@@ -62,8 +66,9 @@ def convert_conversation_history_to_ask_history(
6266
conversation_history: list[ConversationHistory],
6367
) -> list[AskHistory]:
6468
return [
65-
AskHistory(question=history.request, sql=history.response)
69+
AskHistory(question=history.request.query, sql=history.response["sql"])
6670
for history in conversation_history
71+
if history.response.get("sql")
6772
]
6873

6974

@@ -242,6 +247,29 @@ def _run_data_exploration_assistance(
242247
query_id
243248
)
244249

250+
async def _run_chart_generation(
251+
self,
252+
query: str,
253+
sql: str,
254+
data: Dict,
255+
language: str,
256+
):
257+
chart_generation_result = await self._pipelines["chart_generation"].run(
258+
query=query,
259+
sql=sql,
260+
data=data,
261+
language=language,
262+
)
263+
264+
return [
265+
{
266+
"chart_result": chart_generation_result["post_process"]["results"],
267+
"sql": sql,
268+
}
269+
], {
270+
"chart_result": chart_generation_result["post_process"]["results"],
271+
}
272+
245273
async def _run_retrieval(
246274
self,
247275
query: str,
@@ -514,6 +542,7 @@ async def start_conversation(
514542
},
515543
content_block_label="HISTORICAL_QUESTION_RETRIEVAL",
516544
block_type="tool_use",
545+
should_put_in_conversation_history=True,
517546
):
518547
sql_samples = await self._query_event_manager.emit_content_block(
519548
query_id,
@@ -566,6 +595,7 @@ async def start_conversation(
566595
"rephrased_question"
567596
)
568597
db_schemas = intent_classification_result.get("db_schemas")
598+
intent_sql = intent_classification_result.get("sql")
569599

570600
if rephrased_question:
571601
user_query = rephrased_question
@@ -631,6 +661,22 @@ async def start_conversation(
631661
block_type="text",
632662
stream=True,
633663
)
664+
elif intent == "CHART":
665+
await self._query_event_manager.emit_content_block(
666+
query_id,
667+
trace_id,
668+
index=4,
669+
emit_content_func=self._run_chart_generation,
670+
emit_content_func_kwargs={
671+
"query": user_query,
672+
"sql": intent_sql,
673+
"data": sql_data,
674+
"language": configurations.language,
675+
},
676+
content_block_label="CHART_GENERATION",
677+
block_type="tool_use",
678+
should_put_in_conversation_history=True,
679+
)
634680
else: # TEXT_TO_SQL
635681
retrieval_results = (
636682
await self._query_event_manager.emit_content_block(
@@ -721,6 +767,7 @@ async def start_conversation(
721767
},
722768
content_block_label="SQL_GENERATION",
723769
block_type="tool_use",
770+
should_put_in_conversation_history=True,
724771
)
725772
else:
726773
text_to_sql_generation_results = await self._query_event_manager.emit_content_block(
@@ -742,6 +789,7 @@ async def start_conversation(
742789
},
743790
content_block_label="SQL_GENERATION",
744791
block_type="tool_use",
792+
should_put_in_conversation_history=True,
745793
)
746794

747795
sql = ""

0 commit comments

Comments
Β (0)