Skip to content

Commit 1e74402

Browse files
committed
fix chart adjustment
1 parent 5559d34 commit 1e74402

File tree

8 files changed

+720
-72
lines changed

8 files changed

+720
-72
lines changed

wren-ai-service/src/globals.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def create_service_container(
262262
),
263263
conversation_service=v2_services.ConversationService(
264264
pipelines={
265-
"intent_classification": generation.IntentClassification(
265+
"intent_classification": generation.IntentClassificationV2(
266266
**pipe_components["intent_classification"],
267267
wren_ai_docs=wren_ai_docs,
268268
),
@@ -339,7 +339,7 @@ def create_service_container(
339339
"chart_generation": generation.ChartGeneration(
340340
**pipe_components["chart_generation"],
341341
),
342-
"chart_adjustment": generation.ChartAdjustment(
342+
"chart_adjustment": generation.ChartAdjustmentV2(
343343
**pipe_components["chart_adjustment"],
344344
),
345345
},

wren-ai-service/src/pipelines/generation/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from .chart_adjustment import ChartAdjustment
2+
from .chart_adjustment_v2 import ChartAdjustmentV2
23
from .chart_generation import ChartGeneration
34
from .data_assistance import DataAssistance
45
from .data_exploration_assistance import DataExplorationAssistance
56
from .followup_sql_generation import FollowUpSQLGeneration
67
from .followup_sql_generation_reasoning import FollowUpSQLGenerationReasoning
78
from .intent_classification import IntentClassification
9+
from .intent_classification_v2 import IntentClassificationV2
810
from .misleading_assistance import MisleadingAssistance
911
from .question_recommendation import QuestionRecommendation
1012
from .relationship_recommendation import RelationshipRecommendation
@@ -20,9 +22,11 @@
2022
__all__ = [
2123
"ChartGeneration",
2224
"ChartAdjustment",
25+
"ChartAdjustmentV2",
2326
"DataAssistance",
2427
"FollowUpSQLGeneration",
2528
"IntentClassification",
29+
"IntentClassificationV2",
2630
"QuestionRecommendation",
2731
"RelationshipRecommendation",
2832
"SemanticsDescription",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import logging
2+
import sys
3+
from typing import Any, Dict
4+
5+
import orjson
6+
from hamilton import base
7+
from hamilton.async_driver import AsyncDriver
8+
from haystack.components.builders.prompt_builder import PromptBuilder
9+
from langfuse.decorators import observe
10+
11+
from src.core.pipeline import BasicPipeline
12+
from src.core.provider import LLMProvider
13+
from src.pipelines.generation.utils.chart import (
14+
ChartDataPreprocessor,
15+
ChartGenerationPostProcessor,
16+
ChartGenerationResults,
17+
chart_generation_instructions,
18+
)
19+
20+
logger = logging.getLogger("wren-ai-service")
21+
22+
23+
chart_adjustment_system_prompt = f"""
24+
### TASK ###
25+
26+
You are a data analyst great at visualizing data using vega-lite! Given the user's request, SQL, sample data, sample column values, original vega-lite schema,
27+
you need to re-generate vega-lite schema in JSON and provide suitable chart type.
28+
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data, sample column values, original vega-lite schema and adjustment options.
29+
30+
{chart_generation_instructions}
31+
- If you think the user's request are not suitable for the data or not suitable for generating the chart, you can return an empty string for the schema and chart type and give reasoning to explain why.
32+
33+
### OUTPUT FORMAT ###
34+
35+
Please provide your chain of thought reasoning, chart type and the vega-lite schema in JSON format.
36+
37+
{{
38+
"reasoning": <REASON_TO_CHOOSE_THE_SCHEMA_IN_STRING_FORMATTED_IN_LANGUAGE_PROVIDED_BY_USER>,
39+
"chart_type": "line" | "multi_line" | "bar" | "pie" | "grouped_bar" | "stacked_bar" | "area" | "",
40+
"chart_schema": <VEGA_LITE_JSON_SCHEMA>
41+
}}
42+
"""
43+
44+
chart_adjustment_user_prompt_template = """
45+
### INPUT ###
46+
User's request: {{ query }}
47+
User's SQL: {{ sql }}
48+
User's Vega-Lite Schema: {{ chart_schema }}
49+
Sample Data: {{ sample_data }}
50+
Sample Column Values: {{ sample_column_values }}
51+
Language: {{ language }}
52+
53+
Please think step by step
54+
"""
55+
56+
57+
## Start of Pipeline
58+
@observe(capture_input=False)
def preprocess_data(
    data: Dict[str, Any], chart_data_preprocessor: ChartDataPreprocessor
) -> dict:
    """Hand the raw query result to the chart data preprocessor.

    Args:
        data: Query result passed in by the pipeline's ``run`` inputs.
        chart_data_preprocessor: Component whose ``run`` method does the work.

    Returns:
        Whatever ``chart_data_preprocessor.run`` returns; the downstream
        ``prompt`` and ``post_process`` steps read its ``"sample_data"`` and
        ``"sample_column_values"`` entries.
    """
    preprocessed = chart_data_preprocessor.run(data)
    return preprocessed
63+
64+
65+
@observe(capture_input=False)
def prompt(
    query: str,
    sql: str,
    chart_schema: dict,
    preprocess_data: dict,
    language: str,
    prompt_builder: PromptBuilder,
) -> dict:
    """Render the user prompt for the chart adjustment request.

    Args:
        query: The user's adjustment request.
        sql: SQL associated with the chart.
        chart_schema: The current Vega-Lite schema to be adjusted.
        preprocess_data: Output of the ``preprocess_data`` step.
        language: Language passed through to the prompt template.
        prompt_builder: Builder holding the user prompt template.

    Returns:
        The result of ``prompt_builder.run``; the generation step reads its
        ``"prompt"`` entry.
    """
    # Missing preprocessing keys resolve to None rather than raising.
    template_variables = {
        "query": query,
        "sql": sql,
        "chart_schema": chart_schema,
        "sample_data": preprocess_data.get("sample_data"),
        "sample_column_values": preprocess_data.get("sample_column_values"),
        "language": language,
    }
    return prompt_builder.run(**template_variables)
85+
86+
87+
@observe(as_type="generation", capture_input=False)
async def generate_chart_adjustment(prompt: dict, generator: Any) -> dict:
    """Send the rendered prompt to the LLM generator and await its reply.

    Args:
        prompt: Output of the ``prompt`` step; only its ``"prompt"`` entry is used.
        generator: Awaitable callable accepting a ``prompt`` keyword argument.

    Returns:
        The generator's response; the post-processing step reads its
        ``"replies"`` entry.
    """
    rendered_prompt = prompt.get("prompt")
    return await generator(prompt=rendered_prompt)
90+
91+
92+
@observe(capture_input=False)
def post_process(
    generate_chart_adjustment: dict,
    vega_schema: Dict[str, Any],
    preprocess_data: dict,
    post_processor: ChartGenerationPostProcessor,
) -> dict:
    """Pass the LLM replies through the chart post-processor.

    Args:
        generate_chart_adjustment: Output of the generation step.
        vega_schema: Parsed Vega-Lite JSON schema loaded by the pipeline.
        preprocess_data: Output of the ``preprocess_data`` step.
        post_processor: Component whose ``run`` method produces the final result.

    Returns:
        Whatever ``post_processor.run`` returns.

    Raises:
        KeyError: If ``preprocess_data`` has no ``"sample_data"`` entry.
    """
    replies = generate_chart_adjustment.get("replies")
    # Indexed (not .get) on purpose: sample data is required at this stage.
    sample_data = preprocess_data["sample_data"]
    return post_processor.run(replies, vega_schema, sample_data)
104+
105+
106+
## End of Pipeline
107+
CHART_ADJUSTMENT_MODEL_KWARGS = {
108+
"response_format": {
109+
"type": "json_schema",
110+
"json_schema": {
111+
"name": "chart_adjustment_results",
112+
"schema": ChartGenerationResults.model_json_schema(),
113+
},
114+
}
115+
}
116+
117+
118+
class ChartAdjustmentV2(BasicPipeline):
    """Pipeline that regenerates a Vega-Lite chart from a user's adjustment request.

    The steps are the module-level functions of this file, executed through a
    Hamilton ``AsyncDriver`` built over this module; ``run`` requests the
    ``post_process`` node.
    """

    def __init__(
        self,
        llm_provider: LLMProvider,
        **kwargs,
    ):
        """Build the pipeline components and load the Vega-Lite JSON schema.

        Args:
            llm_provider: Provider used to create the generator with the chart
                adjustment system prompt and JSON-schema response format.
            **kwargs: Accepted for compatibility with the shared pipeline
                construction call; not used here.
        """
        self._components = {
            "prompt_builder": PromptBuilder(
                template=chart_adjustment_user_prompt_template
            ),
            "generator": llm_provider.get_generator(
                system_prompt=chart_adjustment_system_prompt,
                generation_kwargs=CHART_ADJUSTMENT_MODEL_KWARGS,
            ),
            "chart_data_preprocessor": ChartDataPreprocessor(),
            "post_processor": ChartGenerationPostProcessor(),
        }

        # Read raw bytes so decoding does not depend on the platform's default
        # text encoding; orjson parses UTF-8 bytes directly.
        # NOTE: the path is relative, so it resolves against the process
        # working directory.
        with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "rb") as f:
            _vega_schema = orjson.loads(f.read())

        self._configs = {
            "vega_schema": _vega_schema,
        }
        super().__init__(
            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
        )

    @observe(name="Chart Adjustment")
    async def run(
        self,
        query: str,
        sql: str,
        chart_schema: dict,
        data: dict,
        language: str,
    ) -> dict:
        """Execute the pipeline and return the driver's result.

        Args:
            query: The user's adjustment request.
            sql: SQL associated with the chart.
            chart_schema: The current Vega-Lite schema to adjust.
            data: Query result fed to the preprocessing step.
            language: Language passed through to the prompt template.

        Returns:
            The result of executing the ``post_process`` node.
        """
        logger.info("Chart Adjustment pipeline is running...")

        return await self._pipe.execute(
            ["post_process"],
            inputs={
                "query": query,
                "sql": sql,
                "chart_schema": chart_schema,
                "data": data,
                "language": language,
                **self._components,
                **self._configs,
            },
        )
169+
170+
171+
if __name__ == "__main__":
    from src.pipelines.common import dry_run_pipeline

    # Minimal placeholder inputs for a manual dry run of the pipeline.
    dry_run_inputs = {
        "query": "show me the dataset",
        "sql": "",
        "chart_schema": {},
        "data": {},
        "language": "English",
    }
    dry_run_pipeline(ChartAdjustmentV2, "chart_adjustment", **dry_run_inputs)

wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@
6363
6464
### User's QUERY HISTORY ###
6565
{% for history in histories %}
66-
User's Question:
66+
Question:
6767
{{ history.question }}
68-
Assistant's Response:
68+
SQL:
6969
{{ history.sql }}
7070
{% endfor %}
7171

wren-ai-service/src/pipelines/generation/intent_classification.py

+6-51
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323

2424
intent_classification_system_prompt = """
2525
### Task ###
26-
You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema or sql data if provided.
27-
Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, `CHART`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
26+
You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema. Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
2827
2928
### Instructions ###
3029
- **Follow the user's previous questions:** If there are previous questions, try to understand the user's current question as following the previous questions.
@@ -36,20 +35,6 @@
3635
3736
### Intent Definitions ###
3837
39-
<DATA_EXPLORATION>
40-
**When to Use:**
41-
- The user's question is about data exploration such as asking for data details, asking for explanation of the data, asking for insights, asking for recommendations, asking for comparison, etc.
42-
43-
**Requirements:**
44-
- SQL DATA is provided and the user's question is about exploring the data.
45-
- The user's question can be answered by the SQL DATA.
46-
47-
**Examples:**
48-
- "Show me the part where the data appears abnormal"
49-
- "Please explain the data in the table"
50-
- "What's the trend of the data?"
51-
</DATA_EXPLORATION>
52-
5338
<TEXT_TO_SQL>
5439
**When to Use:**
5540
- The user's inputs are about modifying SQL from previous questions.
@@ -59,27 +44,13 @@
5944
**Requirements:**
6045
- Include specific table and column names from the schema in your reasoning or modifying SQL from previous questions.
6146
- Reference phrases from the user's inputs that clearly relate to the schema.
62-
- The SQL DATA is not provided or SQL DATA cannot answer the user's question, and the user's question can be answered given the database schema.
6347
6448
**Examples:**
6549
- "What is the total sales for last quarter?"
6650
- "Show me all customers who purchased product X."
6751
- "List the top 10 products by revenue."
6852
</TEXT_TO_SQL>
6953
70-
<CHART>
71-
**When to Use:**
72-
- The user's question is about generating a chart.
73-
74-
**Requirements:**
75-
- The user's question can be answered by the SQL DATA.
76-
- SQL DATA is provided.
77-
- Should pick a SQL from user query histories and the picked SQL should be reflected to the SQL DATA provided.
78-
79-
**Examples:**
80-
- "Show me the bar chart of the data"
81-
</CHART>
82-
8354
<GENERAL>
8455
**When to Use:**
8556
- The user seeks general information about the database schema or its overall capabilities.
@@ -112,11 +83,9 @@
11283
- The user's inputs is irrelevant to the database schema or includes SQL code.
11384
- The user's inputs lacks specific details (like table names or columns) needed to generate an SQL query.
11485
- It appears off-topic or is simply a casual conversation starter.
115-
- The user's question is about generating a chart but the SQL DATA is not provided.
11686
11787
**Requirements:**
118-
- For generating SQL: respond to users by incorporating phrases from the user's inputs that indicate the lack of relevance to the database schema.
119-
- For generating chart: respond to users that we can generate chart only if there is some data available.
88+
- Incorporate phrases from the user's inputs that indicate the lack of relevance to the database schema.
12089
12190
**Examples:**
12291
- "How are you?"
@@ -130,8 +99,7 @@
13099
{
131100
"rephrased_question": "<rephrased question in full standalone question if there are previous questions, otherwise the original question>",
132101
"reasoning": "<brief chain-of-thought reasoning (max 20 words)>",
133-
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE" | "DATA_EXPLORATION" | "CHART",
134-
"sql": "<sql query to be used for generating chart if the intent is CHART, otherwise an empty string>"
102+
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE"
135103
}
136104
"""
137105

@@ -161,18 +129,13 @@
161129
- {{doc.path}}: {{doc.content}}
162130
{% endfor %}
163131
164-
{% if sql_data %}
165-
### SQL DATA ###
166-
{{ sql_data }}
167-
{% endif %}
168-
169132
### INPUT ###
170133
{% if histories %}
171134
User's previous questions:
172135
{% for history in histories %}
173-
User's Question:
136+
Question:
174137
{{ history.question }}
175-
Assistant's Response:
138+
SQL:
176139
{{ history.sql }}
177140
{% endfor %}
178141
{% endif %}
@@ -296,7 +259,6 @@ def prompt(
296259
construct_db_schemas: list[str],
297260
histories: list[AskHistory],
298261
prompt_builder: PromptBuilder,
299-
sql_data: dict,
300262
sql_samples: Optional[list[dict]] = None,
301263
instructions: Optional[list[dict]] = None,
302264
configuration: Configuration | None = None,
@@ -313,7 +275,6 @@ def prompt(
313275
),
314276
current_time=configuration.show_current_time(),
315277
docs=wren_ai_docs,
316-
sql_data=sql_data,
317278
)
318279

319280

@@ -330,15 +291,13 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict
330291
"rephrased_question": results["rephrased_question"],
331292
"intent": results["results"],
332293
"reasoning": results["reasoning"],
333-
"sql": results["sql"],
334294
"db_schemas": construct_db_schemas,
335295
}
336296
except Exception:
337297
return {
338298
"rephrased_question": "",
339299
"intent": "TEXT_TO_SQL",
340300
"reasoning": "",
341-
"sql": "",
342301
"db_schemas": construct_db_schemas,
343302
}
344303

@@ -348,9 +307,7 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict
348307

349308
class IntentClassificationResult(BaseModel):
350309
rephrased_question: str
351-
results: Literal[
352-
"MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE", "DATA_EXPLORATION"
353-
]
310+
results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE"]
354311
reasoning: str
355312

356313

@@ -412,7 +369,6 @@ async def run(
412369
sql_samples: Optional[list[dict]] = None,
413370
instructions: Optional[list[dict]] = None,
414371
configuration: Configuration = Configuration(),
415-
sql_data: Optional[dict] = None,
416372
):
417373
logger.info("Intent Classification pipeline is running...")
418374
return await self._pipe.execute(
@@ -424,7 +380,6 @@ async def run(
424380
"sql_samples": sql_samples or [],
425381
"instructions": instructions or [],
426382
"configuration": configuration,
427-
"sql_data": sql_data or {},
428383
**self._components,
429384
**self._configs,
430385
},

0 commit comments

Comments
(0)