Skip to content

feat(wren-ai-service): flexible chart generation #1652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 39 additions & 51 deletions wren-ai-service/src/pipelines/generation/chart_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import sys
from typing import Any, Dict

import orjson
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
Expand All @@ -14,34 +13,36 @@
ChartDataPreprocessor,
ChartGenerationPostProcessor,
ChartGenerationResults,
chart_generation_instructions,
)
from src.web.v1.services.chart_adjustment import ChartAdjustmentOption

logger = logging.getLogger("wren-ai-service")


chart_adjustment_system_prompt = f"""
def gen_chart_adjustment_system_prompt() -> str:
return """
### TASK ###

You are a data analyst great at visualizing data using vega-lite! Given the user's question, SQL, sample data, sample column values, original vega-lite schema and adjustment options,
you need to re-generate vega-lite schema in JSON and provide suitable chart type.
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data, sample column values, original vega-lite schema and adjustment options.
You are a data analyst great at generating data visualization using vega-lite! Given the user's question, SQL, sample data, sample column values, original vega-lite schema and adjustment command,
you need to think about the best chart type and generate corresponding vega-lite schema in JSON format.
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data, sample column values, original vega-lite schema and adjustment command.

{chart_generation_instructions}
- If you think the adjustment options are not suitable for the data, you can return an empty string for the schema and chart type and give reasoning to explain why.
### INSTRUCTIONS ###

- You need to generate the new vega-lite schema based on the adjustment command and the original vega-lite schema.
- If you think the adjustment command is not suitable for the data, you can return an empty string for the schema and chart type and give reasoning to explain why.
- If the user provides an image, you need to use the image as reference to generate a new chart schema that follows user's adjustment command.

### OUTPUT FORMAT ###

Please provide your chain of thought reasoning, chart type and the vega-lite schema in JSON format.
Please provide your chain of thought reasoning, and the vega-lite schema in JSON format.

{{
{
"reasoning": <REASON_TO_CHOOSE_THE_SCHEMA_IN_STRING_FORMATTED_IN_LANGUAGE_PROVIDED_BY_USER>,
"chart_type": "line" | "multi_line" | "bar" | "pie" | "grouped_bar" | "stacked_bar" | "area" | "",
"chart_schema": <VEGA_LITE_JSON_SCHEMA>
}}
}
"""


chart_adjustment_user_prompt_template = """
### INPUT ###
Original Question: {{ query }}
Expand All @@ -51,25 +52,7 @@
Sample Column Values: {{ sample_column_values }}
Language: {{ language }}

Adjustment Options:
- Chart Type: {{ adjustment_option.chart_type }}
{% if adjustment_option.chart_type != "pie" %}
{% if adjustment_option.x_axis %}
- X Axis: {{ adjustment_option.x_axis }}
{% endif %}
{% if adjustment_option.y_axis %}
- Y Axis: {{ adjustment_option.y_axis }}
{% endif %}
{% endif %}
{% if adjustment_option.x_offset and adjustment_option.chart_type == "grouped_bar" %}
- X Offset: {{ adjustment_option.x_offset }}
{% endif %}
{% if adjustment_option.color and adjustment_option.chart_type != "area" %}
- Color: {{ adjustment_option.color }}
{% endif %}
{% if adjustment_option.theta and adjustment_option.chart_type == "pie" %}
- Theta: {{ adjustment_option.theta }}
{% endif %}
Adjustment Command: {{ adjustment_command }}

Please think step by step
"""
Expand All @@ -78,7 +61,8 @@
## Start of Pipeline
@observe(capture_input=False)
def preprocess_data(
data: Dict[str, Any], chart_data_preprocessor: ChartDataPreprocessor
data: Dict[str, Any],
chart_data_preprocessor: ChartDataPreprocessor,
) -> dict:
return chart_data_preprocessor.run(data)

Expand All @@ -87,7 +71,7 @@ def preprocess_data(
def prompt(
query: str,
sql: str,
adjustment_option: ChartAdjustmentOption,
adjustment_command: str,
chart_schema: dict,
preprocess_data: dict,
language: str,
Expand All @@ -99,7 +83,7 @@ def prompt(
return prompt_builder.run(
query=query,
sql=sql,
adjustment_option=adjustment_option,
adjustment_command=adjustment_command,
chart_schema=chart_schema,
sample_data=sample_data,
sample_column_values=sample_column_values,
Expand All @@ -108,21 +92,26 @@ def prompt(


@observe(as_type="generation", capture_input=False)
async def generate_chart_adjustment(prompt: dict, generator: Any) -> dict:
return await generator(prompt=prompt.get("prompt"))
async def generate_chart_adjustment(
prompt: dict, image_url: str, generator: Any
) -> dict:
return await generator(prompt=prompt.get("prompt"), image_url=image_url)


@observe(capture_input=False)
def post_process(
generate_chart_adjustment: dict,
vega_schema: Dict[str, Any],
remove_data_from_chart_schema: bool,
preprocess_data: dict,
data_provided: bool,
post_processor: ChartGenerationPostProcessor,
) -> dict:
return post_processor.run(
generate_chart_adjustment.get("replies"),
vega_schema,
preprocess_data["sample_data"],
preprocess_data["raw_data"]
if data_provided
else preprocess_data["sample_data"],
remove_data_from_chart_schema=remove_data_from_chart_schema,
)


Expand All @@ -149,19 +138,13 @@ def __init__(
template=chart_adjustment_user_prompt_template
),
"generator": llm_provider.get_generator(
system_prompt=chart_adjustment_system_prompt,
system_prompt=gen_chart_adjustment_system_prompt(),
generation_kwargs=CHART_ADJUSTMENT_MODEL_KWARGS,
),
"chart_data_preprocessor": ChartDataPreprocessor(),
"post_processor": ChartGenerationPostProcessor(),
}

with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f:
_vega_schema = orjson.loads(f.read())

self._configs = {
"vega_schema": _vega_schema,
}
super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)
Expand All @@ -171,10 +154,13 @@ async def run(
self,
query: str,
sql: str,
adjustment_option: ChartAdjustmentOption,
adjustment_command: str,
chart_schema: dict,
data: dict,
language: str,
remove_data_from_chart_schema: bool = True,
data_provided: bool = False,
image_url: str = "",
) -> dict:
logger.info("Chart Adjustment pipeline is running...")

Expand All @@ -183,12 +169,14 @@ async def run(
inputs={
"query": query,
"sql": sql,
"adjustment_option": adjustment_option,
"adjustment_command": adjustment_command,
"chart_schema": chart_schema,
"data": data,
"language": language,
"remove_data_from_chart_schema": remove_data_from_chart_schema,
"data_provided": data_provided,
"image_url": image_url,
**self._components,
**self._configs,
},
)

Expand All @@ -201,7 +189,7 @@ async def run(
"chart_adjustment",
query="show me the dataset",
sql="",
adjustment_option=ChartAdjustmentOption(),
adjustment_command="",
chart_schema={},
# data={},
language="English",
Expand Down
42 changes: 18 additions & 24 deletions wren-ai-service/src/pipelines/generation/chart_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import sys
from typing import Any, Dict, Optional

import orjson
from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
Expand All @@ -14,30 +13,29 @@
ChartDataPreprocessor,
ChartGenerationPostProcessor,
ChartGenerationResults,
chart_generation_instructions,
)

logger = logging.getLogger("wren-ai-service")

chart_generation_system_prompt = f"""

def gen_chart_gen_system_prompt() -> str:
return """
### TASK ###

You are a data analyst great at visualizing data using vega-lite! Given the user's question, SQL, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type.
You are a data analyst great at generating data visualization using vega-lite! Given the user's question, SQL, sample data and sample column values, you need to think about the best chart type and generate corresponding vega-lite schema in JSON format.
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data and sample column values.

{chart_generation_instructions}

### OUTPUT FORMAT ###

Please provide your chain of thought reasoning, chart type and the vega-lite schema in JSON format.
Please provide your chain of thought reasoning, and the vega-lite schema in JSON format.

{{
{
"reasoning": <REASON_TO_CHOOSE_THE_SCHEMA_IN_STRING_FORMATTED_IN_LANGUAGE_PROVIDED_BY_USER>,
"chart_type": "line" | "multi_line" | "bar" | "pie" | "grouped_bar" | "stacked_bar" | "area" | "",
"chart_schema": <VEGA_LITE_JSON_SCHEMA>
}}
}
"""


chart_generation_user_prompt_template = """
### INPUT ###
Question: {{ query }}
Expand All @@ -53,7 +51,8 @@
## Start of Pipeline
@observe(capture_input=False)
def preprocess_data(
data: Dict[str, Any], chart_data_preprocessor: ChartDataPreprocessor
data: Dict[str, Any],
chart_data_preprocessor: ChartDataPreprocessor,
) -> dict:
return chart_data_preprocessor.run(data)

Expand Down Expand Up @@ -86,16 +85,17 @@ async def generate_chart(prompt: dict, generator: Any) -> dict:
@observe(capture_input=False)
def post_process(
generate_chart: dict,
vega_schema: Dict[str, Any],
remove_data_from_chart_schema: bool,
preprocess_data: dict,
data_provided: bool,
post_processor: ChartGenerationPostProcessor,
) -> dict:
return post_processor.run(
generate_chart.get("replies"),
vega_schema,
preprocess_data["sample_data"],
remove_data_from_chart_schema,
preprocess_data["raw_data"]
if data_provided
else preprocess_data["sample_data"],
remove_data_from_chart_schema=remove_data_from_chart_schema,
)


Expand All @@ -122,20 +122,13 @@ def __init__(
template=chart_generation_user_prompt_template
),
"generator": llm_provider.get_generator(
system_prompt=chart_generation_system_prompt,
system_prompt=gen_chart_gen_system_prompt(),
generation_kwargs=CHART_GENERATION_MODEL_KWARGS,
),
"chart_data_preprocessor": ChartDataPreprocessor(),
"post_processor": ChartGenerationPostProcessor(),
}

with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f:
_vega_schema = orjson.loads(f.read())

self._configs = {
"vega_schema": _vega_schema,
}

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)
Expand All @@ -148,6 +141,7 @@ async def run(
data: dict,
language: str,
remove_data_from_chart_schema: Optional[bool] = True,
data_provided: Optional[bool] = False,
) -> dict:
logger.info("Chart Generation pipeline is running...")
return await self._pipe.execute(
Expand All @@ -158,8 +152,8 @@ async def run(
"data": data,
"language": language,
"remove_data_from_chart_schema": remove_data_from_chart_schema,
"data_provided": data_provided,
**self._components,
**self._configs,
},
)

Expand Down
Loading
Loading