diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py index f0459fc13d..135ea6fa9d 100644 --- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py +++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py @@ -2,7 +2,6 @@ import sys from typing import Any, Dict -import orjson from hamilton import base from hamilton.async_driver import AsyncDriver from haystack.components.builders.prompt_builder import PromptBuilder @@ -14,34 +13,36 @@ ChartDataPreprocessor, ChartGenerationPostProcessor, ChartGenerationResults, - chart_generation_instructions, ) -from src.web.v1.services.chart_adjustment import ChartAdjustmentOption logger = logging.getLogger("wren-ai-service") -chart_adjustment_system_prompt = f""" +def gen_chart_adjustment_system_prompt() -> str: + return """ ### TASK ### -You are a data analyst great at visualizing data using vega-lite! Given the user's question, SQL, sample data, sample column values, original vega-lite schema and adjustment options, -you need to re-generate vega-lite schema in JSON and provide suitable chart type. -Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data, sample column values, original vega-lite schema and adjustment options. +You are a data analyst great at generating data visualization using vega-lite! Given the user's question, SQL, sample data, sample column values, original vega-lite schema and adjustment command, +you need to think about the best chart type and generate corresponding vega-lite schema in JSON format. +Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data, sample column values, original vega-lite schema and adjustment command. 
-{chart_generation_instructions} -- If you think the adjustment options are not suitable for the data, you can return an empty string for the schema and chart type and give reasoning to explain why. +### INSTRUCTIONS ### + +- You need to generate the new vega-lite schema based on the adjustment command and the original vega-lite schema. +- If you think the adjustment command is not suitable for the data, you can return an empty string for the schema and chart type and give reasoning to explain why. +- If the user provides an image, you need to use the image as reference to generate a new chart schema that follows user's adjustment command. ### OUTPUT FORMAT ### -Please provide your chain of thought reasoning, chart type and the vega-lite schema in JSON format. +Please provide your chain of thought reasoning, and the vega-lite schema in JSON format. -{{ +{ "reasoning": , - "chart_type": "line" | "multi_line" | "bar" | "pie" | "grouped_bar" | "stacked_bar" | "area" | "", "chart_schema": -}} +} """ + chart_adjustment_user_prompt_template = """ ### INPUT ### Original Question: {{ query }} @@ -51,25 +52,7 @@ Sample Column Values: {{ sample_column_values }} Language: {{ language }} -Adjustment Options: -- Chart Type: {{ adjustment_option.chart_type }} -{% if adjustment_option.chart_type != "pie" %} -{% if adjustment_option.x_axis %} -- X Axis: {{ adjustment_option.x_axis }} -{% endif %} -{% if adjustment_option.y_axis %} -- Y Axis: {{ adjustment_option.y_axis }} -{% endif %} -{% endif %} -{% if adjustment_option.x_offset and adjustment_option.chart_type == "grouped_bar" %} -- X Offset: {{ adjustment_option.x_offset }} -{% endif %} -{% if adjustment_option.color and adjustment_option.chart_type != "area" %} -- Color: {{ adjustment_option.color }} -{% endif %} -{% if adjustment_option.theta and adjustment_option.chart_type == "pie" %} -- Theta: {{ adjustment_option.theta }} -{% endif %} +Adjustment Command: {{ adjustment_command }} Please think step by step """ @@ -78,7 
+61,8 @@ ## Start of Pipeline @observe(capture_input=False) def preprocess_data( - data: Dict[str, Any], chart_data_preprocessor: ChartDataPreprocessor + data: Dict[str, Any], + chart_data_preprocessor: ChartDataPreprocessor, ) -> dict: return chart_data_preprocessor.run(data) @@ -87,7 +71,7 @@ def preprocess_data( def prompt( query: str, sql: str, - adjustment_option: ChartAdjustmentOption, + adjustment_command: str, chart_schema: dict, preprocess_data: dict, language: str, @@ -99,7 +83,7 @@ def prompt( return prompt_builder.run( query=query, sql=sql, - adjustment_option=adjustment_option, + adjustment_command=adjustment_command, chart_schema=chart_schema, sample_data=sample_data, sample_column_values=sample_column_values, @@ -108,21 +92,26 @@ def prompt( @observe(as_type="generation", capture_input=False) -async def generate_chart_adjustment(prompt: dict, generator: Any) -> dict: - return await generator(prompt=prompt.get("prompt")) +async def generate_chart_adjustment( + prompt: dict, image_url: str, generator: Any +) -> dict: + return await generator(prompt=prompt.get("prompt"), image_url=image_url) @observe(capture_input=False) def post_process( generate_chart_adjustment: dict, - vega_schema: Dict[str, Any], + remove_data_from_chart_schema: bool, preprocess_data: dict, + data_provided: bool, post_processor: ChartGenerationPostProcessor, ) -> dict: return post_processor.run( generate_chart_adjustment.get("replies"), - vega_schema, - preprocess_data["sample_data"], + preprocess_data["raw_data"] + if data_provided + else preprocess_data["sample_data"], + remove_data_from_chart_schema=remove_data_from_chart_schema, ) @@ -149,19 +138,13 @@ def __init__( template=chart_adjustment_user_prompt_template ), "generator": llm_provider.get_generator( - system_prompt=chart_adjustment_system_prompt, + system_prompt=gen_chart_adjustment_system_prompt(), generation_kwargs=CHART_ADJUSTMENT_MODEL_KWARGS, ), "chart_data_preprocessor": ChartDataPreprocessor(), "post_processor": 
ChartGenerationPostProcessor(), } - with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: - _vega_schema = orjson.loads(f.read()) - - self._configs = { - "vega_schema": _vega_schema, - } super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -171,10 +154,13 @@ async def run( self, query: str, sql: str, - adjustment_option: ChartAdjustmentOption, + adjustment_command: str, chart_schema: dict, data: dict, language: str, + remove_data_from_chart_schema: bool = True, + data_provided: bool = False, + image_url: str = "", ) -> dict: logger.info("Chart Adjustment pipeline is running...") @@ -183,12 +169,14 @@ async def run( inputs={ "query": query, "sql": sql, - "adjustment_option": adjustment_option, + "adjustment_command": adjustment_command, "chart_schema": chart_schema, "data": data, "language": language, + "remove_data_from_chart_schema": remove_data_from_chart_schema, + "data_provided": data_provided, + "image_url": image_url, **self._components, - **self._configs, }, ) @@ -201,7 +189,7 @@ async def run( "chart_adjustment", query="show me the dataset", sql="", - adjustment_option=ChartAdjustmentOption(), + adjustment_command="", chart_schema={}, # data={}, language="English", diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py index 8cab43a49b..eb572b838f 100644 --- a/wren-ai-service/src/pipelines/generation/chart_generation.py +++ b/wren-ai-service/src/pipelines/generation/chart_generation.py @@ -2,7 +2,6 @@ import sys from typing import Any, Dict, Optional -import orjson from hamilton import base from hamilton.async_driver import AsyncDriver from haystack.components.builders.prompt_builder import PromptBuilder @@ -14,30 +13,29 @@ ChartDataPreprocessor, ChartGenerationPostProcessor, ChartGenerationResults, - chart_generation_instructions, ) logger = logging.getLogger("wren-ai-service") 
-chart_generation_system_prompt = f""" + +def gen_chart_gen_system_prompt() -> str: + return """ ### TASK ### -You are a data analyst great at visualizing data using vega-lite! Given the user's question, SQL, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type. +You are a data analyst great at generating data visualization using vega-lite! Given the user's question, SQL, sample data and sample column values, you need to think about the best chart type and generate corresponding vega-lite schema in JSON format. Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, SQL, sample data and sample column values. -{chart_generation_instructions} - ### OUTPUT FORMAT ### -Please provide your chain of thought reasoning, chart type and the vega-lite schema in JSON format. +Please provide your chain of thought reasoning, and the vega-lite schema in JSON format. 
-{{ +{ "reasoning": , - "chart_type": "line" | "multi_line" | "bar" | "pie" | "grouped_bar" | "stacked_bar" | "area" | "", "chart_schema": -}} +} """ + chart_generation_user_prompt_template = """ ### INPUT ### Question: {{ query }} @@ -53,7 +51,8 @@ ## Start of Pipeline @observe(capture_input=False) def preprocess_data( - data: Dict[str, Any], chart_data_preprocessor: ChartDataPreprocessor + data: Dict[str, Any], + chart_data_preprocessor: ChartDataPreprocessor, ) -> dict: return chart_data_preprocessor.run(data) @@ -86,16 +85,17 @@ async def generate_chart(prompt: dict, generator: Any) -> dict: @observe(capture_input=False) def post_process( generate_chart: dict, - vega_schema: Dict[str, Any], remove_data_from_chart_schema: bool, preprocess_data: dict, + data_provided: bool, post_processor: ChartGenerationPostProcessor, ) -> dict: return post_processor.run( generate_chart.get("replies"), - vega_schema, - preprocess_data["sample_data"], - remove_data_from_chart_schema, + preprocess_data["raw_data"] + if data_provided + else preprocess_data["sample_data"], + remove_data_from_chart_schema=remove_data_from_chart_schema, ) @@ -122,20 +122,13 @@ def __init__( template=chart_generation_user_prompt_template ), "generator": llm_provider.get_generator( - system_prompt=chart_generation_system_prompt, + system_prompt=gen_chart_gen_system_prompt(), generation_kwargs=CHART_GENERATION_MODEL_KWARGS, ), "chart_data_preprocessor": ChartDataPreprocessor(), "post_processor": ChartGenerationPostProcessor(), } - with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: - _vega_schema = orjson.loads(f.read()) - - self._configs = { - "vega_schema": _vega_schema, - } - super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -148,6 +141,7 @@ async def run( data: dict, language: str, remove_data_from_chart_schema: Optional[bool] = True, + data_provided: Optional[bool] = False, ) -> dict: logger.info("Chart Generation pipeline is 
running...") return await self._pipe.execute( @@ -158,8 +152,8 @@ async def run( "data": data, "language": language, "remove_data_from_chart_schema": remove_data_from_chart_schema, + "data_provided": data_provided, **self._components, - **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/utils/chart.py b/wren-ai-service/src/pipelines/generation/utils/chart.py index 5d06b949a9..576ac99607 100644 --- a/wren-ai-service/src/pipelines/generation/utils/chart.py +++ b/wren-ai-service/src/pipelines/generation/utils/chart.py @@ -1,247 +1,18 @@ import logging -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, Optional import orjson import pandas as pd from haystack import component -from jsonschema import validate from jsonschema.exceptions import ValidationError -from pydantic import BaseModel, Field +from pydantic import BaseModel logger = logging.getLogger("wren-ai-service") -chart_generation_instructions = """ -### INSTRUCTIONS ### - -- Chart types: Bar chart, Line chart, Multi line chart, Area chart, Pie chart, Stacked bar chart, Grouped bar chart -- You can only use the chart types provided in the instructions -- Generated chart should answer the user's question and based on the semantics of the SQL query, and the sample data, sample column values are used to help you generate the suitable chart type -- If the sample data is not suitable for visualization, you must return an empty string for the schema and chart type -- If the sample data is empty, you must return an empty string for the schema and chart type -- The language for the chart and reasoning must be the same language provided by the user -- Please use the current time provided by the user to generate the chart -- In order to generate the grouped bar chart, you need to follow the given instructions: - - Disable Stacking: Add "stack": null to the y-encoding. - - Use xOffset for subcategories to group bars. - - Don't use "transform" section. 
-- In order to generate the pie chart, you need to follow the given instructions: - - Add {"type": "arc"} to the mark section. - - Add "theta" encoding to the encoding section. - - Add "color" encoding to the encoding section. - - Don't add "innerRadius" to the mark section. -- If the x-axis of the chart is a temporal field, the time unit should be the same as the question user asked. - - For yearly question, the time unit should be "year". - - For monthly question, the time unit should be "yearmonth". - - For weekly question, the time unit should be "yearmonthdate". - - For daily question, the time unit should be "yearmonthdate". - - Default time unit is "yearmonth". -- For each axis, generate the corresponding human-readable title based on the language provided by the user. -- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data. - -### GUIDELINES TO PLOT CHART ### - -1. Understanding Your Data Types -- Nominal (Categorical): Names or labels without a specific order (e.g., types of fruits, countries). -- Ordinal: Categorical data with a meaningful order but no fixed intervals (e.g., rankings, satisfaction levels). -- Quantitative: Numerical values representing counts or measurements (e.g., sales figures, temperatures). -- Temporal: Date or time data (e.g., timestamps, dates). -2. Chart Types and When to Use Them -- Bar Chart - - Use When: Comparing quantities across different categories. - - Data Requirements: - - One categorical variable (x-axis). - - One quantitative variable (y-axis). - - Example: Comparing sales numbers for different product categories. -- Grouped Bar Chart - - Use When: Comparing sub-categories within main categories. - - Data Requirements: - - Two categorical variables (x-axis grouped by one, color-coded by another). - - One quantitative variable (y-axis). - - Example: Sales numbers for different products across various regions. 
-- Line Chart - - Use When: Displaying trends over continuous data, especially time. - - Data Requirements: - - One temporal or ordinal variable (x-axis). - - One quantitative variable (y-axis). - - Example: Tracking monthly revenue over a year. -- Multi Line Chart - - Use When: Displaying trends over continuous data, especially time. - - Data Requirements: - - One temporal or ordinal variable (x-axis). - - Two or more quantitative variables (y-axis and color). - - Implementation Notes: - - Uses `transform` with `fold` to combine multiple metrics into a single series - - The folded metrics are distinguished using the color encoding - - Example: Tracking monthly click rate and read rate over a year. -- Area Chart - - Use When: Similar to line charts but emphasizing the volume of change over time. - - Data Requirements: - - Same as Line Chart. - - Example: Visualizing cumulative rainfall over months. -- Pie Chart - - Use When: Showing parts of a whole as percentages. - - Data Requirements: - - One categorical variable. - - One quantitative variable representing proportions. - - Example: Market share distribution among companies. -- Stacked Bar Chart - - Use When: Showing composition and comparison across categories. - - Data Requirements: Same as grouped bar chart. - - Example: Sales by region and product type. -- Guidelines for Selecting Chart Types - - Comparing Categories: - - Bar Chart: Best for simple comparisons across categories. - - Grouped Bar Chart: Use when you have sub-categories. - - Stacked Bar Chart: Use to show composition within categories. - - Showing Trends Over Time: - - Line Chart: Ideal for continuous data over time. - - Area Chart: Use when you want to emphasize volume or total value over time. - - Displaying Proportions: - - Pie Chart: Use for simple compositions at a single point in time. - - Stacked Bar Chart (100%): Use for comparing compositions across multiple categories. - -### EXAMPLES ### - -1. 
Bar Chart -- Sample Data: - [ - {"Region": "North", "Sales": 100}, - {"Region": "South", "Sales": 200}, - {"Region": "East", "Sales": 300}, - {"Region": "West", "Sales": 400} -] -- Chart Schema: -{ - "title": , - "mark": {"type": "bar"}, - "encoding": { - "x": {"field": "Region", "type": "nominal", "title": }, - "y": {"field": "Sales", "type": "quantitative", "title": }, - "color": {"field": "Region", "type": "nominal", "title": ""} - } -} -2. Line Chart -- Sample Data: -[ - {"Date": "2022-01-01", "Sales": 100}, - {"Date": "2022-01-02", "Sales": 200}, - {"Date": "2022-01-03", "Sales": 300}, - {"Date": "2022-01-04", "Sales": 400} -] -- Chart Schema: -{ - "title": , - "mark": {"type": "line"}, - "encoding": { - "x": {"field": "Date", "type": "temporal", "title": }, - "y": {"field": "Sales", "type": "quantitative", "title": } - } -} -3. Pie Chart -- Sample Data: -[ - {"Company": "Company A", "Market Share": 0.4}, - {"Company": "Company B", "Market Share": 0.3}, - {"Company": "Company C", "Market Share": 0.2}, - {"Company": "Company D", "Market Share": 0.1} -] -- Chart Schema: -{ - "title": , - "mark": {"type": "arc"}, - "encoding": { - "theta": {"field": "Market Share", "type": "quantitative"}, - "color": {"field": "Company", "type": "nominal", "title": } - } -} -4. Area Chart -- Sample Data: -[ - {"Date": "2022-01-01", "Sales": 100}, - {"Date": "2022-01-02", "Sales": 200}, - {"Date": "2022-01-03", "Sales": 300}, - {"Date": "2022-01-04", "Sales": 400} -] -- Chart Schema: -{ - "title": "", - "mark": {"type": "area"}, - "encoding": { - "x": {"field": "Date", "type": "temporal", "title": ""}, - "y": {"field": "Sales", "type": "quantitative", "title": ""} - } -} -5. 
Stacked Bar Chart -- Sample Data: -[ - {"Region": "North", "Product": "A", "Sales": 100}, - {"Region": "North", "Product": "B", "Sales": 150}, - {"Region": "South", "Product": "A", "Sales": 200}, - {"Region": "South", "Product": "B", "Sales": 250}, - {"Region": "East", "Product": "A", "Sales": 300}, - {"Region": "East", "Product": "B", "Sales": 350}, - {"Region": "West", "Product": "A", "Sales": 400}, - {"Region": "West", "Product": "B", "Sales": 450} -] -- Chart Schema: -{ - "title": "", - "mark": {"type": "bar"}, - "encoding": { - "x": {"field": "Region", "type": "nominal", "title": ""}, - "y": {"field": "Sales", "type": "quantitative", "title": "", "stack": "zero"}, - "color": {"field": "Product", "type": "nominal", "title": ""} - } -} -6. Grouped Bar Chart -- Sample Data: -[ - {"Region": "North", "Product": "A", "Sales": 100}, - {"Region": "North", "Product": "B", "Sales": 150}, - {"Region": "South", "Product": "A", "Sales": 200}, - {"Region": "South", "Product": "B", "Sales": 250}, - {"Region": "East", "Product": "A", "Sales": 300}, - {"Region": "East", "Product": "B", "Sales": 350}, - {"Region": "West", "Product": "A", "Sales": 400}, - {"Region": "West", "Product": "B", "Sales": 450} -] -- Chart Schema: -{ - "title": "", - "mark": {"type": "bar"}, - "encoding": { - "x": {"field": "Region", "type": "nominal", "title": ""}, - "y": {"field": "Sales", "type": "quantitative", "title": ""}, - "xOffset": {"field": "Product", "type": "nominal", "title": ""}, - "color": {"field": "Product", "type": "nominal", "title": ""} - } -} -7. 
Multi Line Chart -- Sample Data: -[ - {"Date": "2022-01-01", "readCount": 100, "clickCount": 10}, - {"Date": "2022-01-02", "readCount": 200, "clickCount": 30}, - {"Date": "2022-01-03", "readCount": 300, "clickCount": 20}, - {"Date": "2022-01-04", "readCount": 400, "clickCount": 40} -] -- Chart Schema: -{ - "title": , - "mark": {"type": "line"}, - "transform": [ - { - "fold": ["readCount", "clickCount"], - "as": ["Metric", "Value"] - } - ], - "encoding": { - "x": {"field": "Date", "type": "temporal", "title": }, - "y": {"field": "Value", "type": "quantitative", "title": }, - "color": {"field": "Metric", "type": "nominal", "title": } - } -} -""" +def load_custom_theme() -> Dict[str, Any]: + with open("src/pipelines/generation/utils/theme_powerbi.json", "r") as f: + return orjson.loads(f.read()) @component @@ -273,6 +44,7 @@ def run( sample_data = df.to_dict(orient="records") return { + "raw_data": df.to_dict(orient="records"), "sample_data": sample_data, "sample_column_values": sample_column_values, } @@ -286,7 +58,6 @@ class ChartGenerationPostProcessor: def run( self, replies: str, - vega_schema: Dict[str, Any], sample_data: list[dict], remove_data_from_chart_schema: Optional[bool] = True, ): @@ -303,8 +74,7 @@ def run( "$schema" ] = "https://vega.github.io/schema/vega-lite/v5.json" chart_schema["data"] = {"values": sample_data} - - validate(chart_schema, schema=vega_schema) + chart_schema["config"] = load_custom_theme() if remove_data_from_chart_schema: chart_schema["data"]["values"] = [] @@ -346,135 +116,14 @@ def run( } -class ChartSchema(BaseModel): - class ChartType(BaseModel): - type: Literal["bar", "line", "area", "arc"] - - class ChartEncoding(BaseModel): - field: str - type: Literal["ordinal", "quantitative", "nominal"] - title: str - - title: str - mark: ChartType - encoding: ChartEncoding - - -class TemporalChartEncoding(ChartSchema.ChartEncoding): - type: Literal["temporal"] = Field(default="temporal") - timeUnit: str = Field(default="yearmonth") - - 
-class LineChartSchema(ChartSchema): - class LineChartMark(BaseModel): - type: Literal["line"] = Field(default="line") - - class LineChartEncoding(BaseModel): - x: TemporalChartEncoding | ChartSchema.ChartEncoding - y: ChartSchema.ChartEncoding - color: ChartSchema.ChartEncoding - - mark: LineChartMark - encoding: LineChartEncoding - - -class MultiLineChartSchema(ChartSchema): - class MultiLineChartMark(BaseModel): - type: Literal["line"] = Field(default="line") - - class MultiLineChartTransform(BaseModel): - fold: list[str] - as_: list[str] = Field(alias="as") - - class MultiLineChartEncoding(BaseModel): - x: TemporalChartEncoding | ChartSchema.ChartEncoding - y: ChartSchema.ChartEncoding - color: ChartSchema.ChartEncoding - - mark: MultiLineChartMark - transform: list[MultiLineChartTransform] - encoding: MultiLineChartEncoding - - -class BarChartSchema(ChartSchema): - class BarChartMark(BaseModel): - type: Literal["bar"] = Field(default="bar") +def read_vega_lite_schema() -> Dict[str, Any]: + with open("src/pipelines/generation/utils/vega-lite-schema-v5.json", "r") as f: + vega_lite_schema = orjson.loads(f.read()) - class BarChartEncoding(BaseModel): - x: TemporalChartEncoding | ChartSchema.ChartEncoding - y: ChartSchema.ChartEncoding - color: ChartSchema.ChartEncoding - - mark: BarChartMark - encoding: BarChartEncoding - - -class GroupedBarChartSchema(ChartSchema): - class GroupedBarChartMark(BaseModel): - type: Literal["bar"] = Field(default="bar") - - class GroupedBarChartEncoding(BaseModel): - x: TemporalChartEncoding | ChartSchema.ChartEncoding - y: ChartSchema.ChartEncoding - xOffset: ChartSchema.ChartEncoding - color: ChartSchema.ChartEncoding - - mark: GroupedBarChartMark - encoding: GroupedBarChartEncoding - - -class StackedBarChartYEncoding(ChartSchema.ChartEncoding): - stack: Literal["zero"] = Field(default="zero") - - -class StackedBarChartSchema(ChartSchema): - class StackedBarChartMark(BaseModel): - type: Literal["bar"] = Field(default="bar") - - 
class StackedBarChartEncoding(BaseModel): - x: TemporalChartEncoding | ChartSchema.ChartEncoding - y: StackedBarChartYEncoding - color: ChartSchema.ChartEncoding - - mark: StackedBarChartMark - encoding: StackedBarChartEncoding - - -class PieChartSchema(ChartSchema): - class PieChartMark(BaseModel): - type: Literal["arc"] = Field(default="arc") - - class PieChartEncoding(BaseModel): - theta: ChartSchema.ChartEncoding - color: ChartSchema.ChartEncoding - - mark: PieChartMark - encoding: PieChartEncoding - - -class AreaChartSchema(ChartSchema): - class AreaChartMark(BaseModel): - type: Literal["area"] = Field(default="area") - - class AreaChartEncoding(BaseModel): - x: TemporalChartEncoding | ChartSchema.ChartEncoding - y: ChartSchema.ChartEncoding - - mark: AreaChartMark - encoding: AreaChartEncoding + return vega_lite_schema class ChartGenerationResults(BaseModel): reasoning: str - chart_type: Literal[ - "line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", "" - ] # empty string for no chart - chart_schema: ( - LineChartSchema - | MultiLineChartSchema - | BarChartSchema - | PieChartSchema - | GroupedBarChartSchema - | StackedBarChartSchema - | AreaChartSchema - ) + chart_schema: dict[str, Any] + chart_type: Optional[str] = "" # deprecated diff --git a/wren-ai-service/src/pipelines/generation/utils/theme_powerbi.json b/wren-ai-service/src/pipelines/generation/utils/theme_powerbi.json new file mode 100644 index 0000000000..e744baf335 --- /dev/null +++ b/wren-ai-service/src/pipelines/generation/utils/theme_powerbi.json @@ -0,0 +1,128 @@ +{ + "view": { + "stroke": "transparent" + }, + "background": "transparent", + "font": "Segoe UI", + "header": { + "titleFont": "wf_standard-font, helvetica, arial, sans-serif", + "titleFontSize": 16, + "titleColor": "#252423", + "labelFont": "Segoe UI", + "labelFontSize": 13.333333333333334, + "labelColor": "#605E5C" + }, + "axis": { + "ticks": false, + "grid": false, + "domain": false, + "labelColor": "#605E5C", + 
"labelFontSize": 12, + "titleFont": "wf_standard-font, helvetica, arial, sans-serif", + "titleColor": "#252423", + "titleFontSize": 16, + "titleFontWeight": "normal" + }, + "axisQuantitative": { + "tickCount": 3, + "grid": true, + "gridColor": "#C8C6C4", + "gridDash": [ + 1, + 5 + ], + "labelFlush": false + }, + "axisBand": { + "tickExtra": true + }, + "axisX": { + "labelPadding": 5 + }, + "axisY": { + "labelPadding": 10 + }, + "bar": { + "fill": "#118DFF" + }, + "line": { + "stroke": "#118DFF", + "strokeWidth": 3, + "strokeCap": "round", + "strokeJoin": "round" + }, + "text": { + "font": "Segoe UI", + "fontSize": 12, + "fill": "#605E5C" + }, + "arc": { + "fill": "#118DFF" + }, + "area": { + "fill": "#118DFF", + "line": true, + "opacity": 0.6 + }, + "path": { + "stroke": "#118DFF" + }, + "rect": { + "fill": "#118DFF" + }, + "point": { + "fill": "#118DFF", + "filled": true, + "size": 75 + }, + "shape": { + "stroke": "#118DFF" + }, + "symbol": { + "fill": "#118DFF", + "strokeWidth": 1.5, + "size": 50 + }, + "legend": { + "titleFont": "Segoe UI", + "titleFontWeight": "bold", + "titleColor": "#605E5C", + "labelFont": "Segoe UI", + "labelFontSize": 13.333333333333334, + "labelColor": "#605E5C", + "symbolType": "circle", + "symbolSize": 75 + }, + "range": { + "category": [ + "#118DFF", + "#12239E", + "#E66C37", + "#6B007B", + "#E044A7", + "#744EC2", + "#D9B300", + "#D64550" + ], + "diverging": [ + "#DEEFFF", + "#118DFF" + ], + "heatmap": [ + "#DEEFFF", + "#118DFF" + ], + "ordinal": [ + "#DEEFFF", + "#c7e4ff", + "#b0d9ff", + "#9aceff", + "#83c3ff", + "#6cb9ff", + "#55aeff", + "#3fa3ff", + "#2898ff", + "#118DFF" + ] + } +} \ No newline at end of file diff --git a/wren-ai-service/src/providers/llm/__init__.py b/wren-ai-service/src/providers/llm/__init__.py index 4025c6f29d..9656092c80 100644 --- a/wren-ai-service/src/providers/llm/__init__.py +++ b/wren-ai-service/src/providers/llm/__init__.py @@ -1,11 +1,134 @@ import logging -from typing import Any, List - -from 
haystack.dataclasses import ChatMessage, StreamingChunk +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional logger = logging.getLogger("wren-ai-service") +class ChatRole(str, Enum): + """Enumeration representing the roles within a chat.""" + + ASSISTANT = "assistant" + USER = "user" + SYSTEM = "system" + FUNCTION = "function" + + +@dataclass +class ChatMessage: + """ + Represents a message in a LLM chat conversation. + + :param content: The text content of the message. + :param role: The role of the entity sending the message. + :param name: The name of the function being called (only applicable for role FUNCTION). + :param meta: Additional metadata associated with the message. + """ + + content: str + role: ChatRole + name: Optional[str] + image_url: Optional[str] + meta: Dict[str, Any] = field(default_factory=dict, hash=False) + + def is_from(self, role: ChatRole) -> bool: + """ + Check if the message is from a specific role. + + :param role: The role to check against. + :returns: True if the message is from the specified role, False otherwise. + """ + return self.role == role + + @classmethod + def from_assistant( + cls, content: str, meta: Optional[Dict[str, Any]] = None + ) -> "ChatMessage": + """ + Create a message from the assistant. + + :param content: The text content of the message. + :param meta: Additional metadata associated with the message. + :returns: A new ChatMessage instance. + """ + return cls( + content, ChatRole.ASSISTANT, name=None, image_url=None, meta=meta or {} + ) + + @classmethod + def from_user(cls, content: str, image_url: Optional[str] = None) -> "ChatMessage": + """ + Create a message from the user. + + :param content: The text content of the message. + :returns: A new ChatMessage instance. 
+ """ + return cls(content, ChatRole.USER, name=None, image_url=image_url) + + @classmethod + def from_system(cls, content: str) -> "ChatMessage": + """ + Create a message from the system. + + :param content: The text content of the message. + :returns: A new ChatMessage instance. + """ + return cls(content, ChatRole.SYSTEM, name=None, image_url=None) + + @classmethod + def from_function(cls, content: str, name: str) -> "ChatMessage": + """ + Create a message from a function call. + + :param content: The text content of the message. + :param name: The name of the function being called. + :returns: A new ChatMessage instance. + """ + return cls(content, ChatRole.FUNCTION, name=name, image_url=None, meta=None) + + def to_dict(self) -> Dict[str, Any]: + """ + Converts ChatMessage into a dictionary. + + :returns: + Serialized version of the object. + """ + data = asdict(self) + data["role"] = self.role.value + + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": + """ + Creates a new ChatMessage object from a dictionary. + + :param data: + The dictionary to build the ChatMessage object. + :returns: + The created object. + """ + data["role"] = ChatRole(data["role"]) + + return cls(**data) + + +@dataclass +class StreamingChunk: + """ + The StreamingChunk class encapsulates a segment of streamed content along with associated metadata. + + This structure facilitates the handling and processing of streamed data in a systematic manner. + + :param content: The content of the message chunk as a string. + :param meta: A dictionary containing metadata related to the message chunk. + """ + + content: str + meta: Dict[str, Any] = field(default_factory=dict, hash=False) + + def build_message(completion: Any, choice: Any) -> ChatMessage: """ Converts the response from the OpenAI API to a ChatMessage. 
@@ -96,3 +219,34 @@ def build_chunk(chunk: Any) -> StreamingChunk:
         }
     )
     return chunk_message
+
+
+def convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]:
+    """
+    Convert a message to the format expected by OpenAI's Chat API.
+
+    See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
+
+    :returns: A dictionary with the following keys:
+        - `role`
+        - `content`
+        - `name` (optional)
+    """
+    openai_msg = {"role": message.role.value}
+
+    if message.content and message.image_url:
+        openai_msg["content"] = [
+            {"type": "text", "text": message.content},
+            {"type": "image_url", "image_url": {"url": message.image_url}},
+        ]
+    elif message.content:
+        openai_msg["content"] = message.content
+    elif message.image_url:
+        openai_msg["content"] = [
+            {"type": "image_url", "image_url": {"url": message.image_url}}
+        ]
+
+    if message.name:
+        openai_msg["name"] = message.name
+
+    return openai_msg
diff --git a/wren-ai-service/src/providers/llm/litellm.py b/wren-ai-service/src/providers/llm/litellm.py
index 5c5122f7a3..935e7677c7 100644
--- a/wren-ai-service/src/providers/llm/litellm.py
+++ b/wren-ai-service/src/providers/llm/litellm.py
@@ -1,22 +1,21 @@
 import os
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from haystack.components.generators.openai_utils import (
-    _convert_message_to_openai_format,
-)
-from haystack.dataclasses import ChatMessage, StreamingChunk
 from litellm import acompletion
 from litellm.types.utils import ModelResponse
 
 from src.core.provider import LLMProvider
 from src.providers.llm import (
+    ChatMessage,
+    StreamingChunk,
     build_chunk,
     build_message,
     check_finish_reason,
     connect_chunks,
+    convert_message_to_openai_format,
 )
 from src.providers.loader import provider
-from src.utils import remove_trailing_slash, extract_braces_content
+from src.utils import extract_braces_content, remove_trailing_slash
 
 
 @provider("litellm_llm")
@@ -46,15 +45,19 @@ def get_generator(
         generation_kwargs:
Optional[Dict[str, Any]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): - combined_generation_kwargs = {**(generation_kwargs or {}), **self._model_kwargs} + combined_generation_kwargs = { + **(generation_kwargs or {}), + **(self._model_kwargs or {}), + } async def _run( prompt: str, + image_url: Optional[str] = None, history_messages: Optional[List[ChatMessage]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, query_id: Optional[str] = None, ): - message = ChatMessage.from_user(prompt) + message = ChatMessage.from_user(prompt, image_url) if system_prompt: messages = [ChatMessage.from_system(system_prompt)] if history_messages: @@ -67,7 +70,7 @@ async def _run( messages = [message] openai_formatted_messages = [ - _convert_message_to_openai_format(message) for message in messages + convert_message_to_openai_format(message) for message in messages ] generation_kwargs = { @@ -113,7 +116,9 @@ async def _run( check_finish_reason(response) return { - "replies": [extract_braces_content(message.content) for message in completions], + "replies": [ + extract_braces_content(message.content) for message in completions + ], "meta": [message.meta for message in completions], } diff --git a/wren-ai-service/src/web/v1/services/chart.py b/wren-ai-service/src/web/v1/services/chart.py index 772f3aae48..224ff16902 100644 --- a/wren-ai-service/src/web/v1/services/chart.py +++ b/wren-ai-service/src/web/v1/services/chart.py @@ -66,10 +66,8 @@ class ChartResultRequest(BaseModel): class ChartResult(BaseModel): reasoning: str - chart_type: Literal[ - "line", "bar", "pie", "grouped_bar", "stacked_bar", "area", "multi_line", "" - ] # empty string for no chart - chart_schema: dict + chart_type: Optional[str] = "" + chart_schema: Optional[dict] = None class ChartResultResponse(BaseModel): @@ -116,6 +114,7 @@ async def chart( } try: + data_provided = False query_id = chart_request.query_id if not chart_request.data: @@ -132,6 +131,7 @@ async def 
chart( )["execute_sql"]["results"] else: sql_data = chart_request.data + data_provided = True self._chart_results[query_id] = ChartResultResponse( status="generating", @@ -144,6 +144,7 @@ async def chart( data=sql_data, language=chart_request.configurations.language, remove_data_from_chart_schema=chart_request.remove_data_from_chart_schema, + data_provided=data_provided, ) chart_result = chart_generation_result["post_process"]["results"] diff --git a/wren-ai-service/src/web/v1/services/chart_adjustment.py b/wren-ai-service/src/web/v1/services/chart_adjustment.py index dff22fe010..429324f0ca 100644 --- a/wren-ai-service/src/web/v1/services/chart_adjustment.py +++ b/wren-ai-service/src/web/v1/services/chart_adjustment.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Literal, Optional +from typing import Any, Dict, Literal, Optional from cachetools import TTLCache from langfuse.decorators import observe @@ -28,11 +28,15 @@ class ChartAdjustmentRequest(BaseModel): _query_id: str | None = None query: str sql: str - adjustment_option: ChartAdjustmentOption + adjustment_command: str + data: Optional[Dict[str, Any]] = None + image_url: Optional[str] = None chart_schema: dict project_id: Optional[str] = None thread_id: Optional[str] = None configurations: Optional[Configuration] = Configuration() + remove_data_from_chart_schema: Optional[bool] = True + adjustment_option: Optional[ChartAdjustmentOption] = None # deprecated @property def query_id(self) -> str: @@ -77,9 +81,7 @@ class ChartAdjustmentResultRequest(BaseModel): class ChartAdjustmentResult(BaseModel): reasoning: str - chart_type: Literal[ - "line", "bar", "pie", "grouped_bar", "stacked_bar", "area", "multi_line", "" - ] # empty string for no chart + chart_type: Optional[str] = "" chart_schema: dict @@ -91,6 +93,7 @@ class ChartAdjustmentResultResponse(BaseModel): error: Optional[ChartAdjustmentError] = None trace_id: Optional[str] = None + class ChartAdjustmentService: def __init__( self, @@ -128,6 
+131,7 @@ async def chart_adjustment( } try: + data_provided = False query_id = chart_adjustment_request.query_id self._chart_adjustment_results[query_id] = ChartAdjustmentResultResponse( @@ -135,12 +139,16 @@ async def chart_adjustment( trace_id=trace_id, ) - sql_data = ( - await self._pipelines["sql_executor"].run( - sql=chart_adjustment_request.sql, - project_id=chart_adjustment_request.project_id, - ) - )["execute_sql"]["results"] + if not chart_adjustment_request.data: + sql_data = ( + await self._pipelines["sql_executor"].run( + sql=chart_adjustment_request.sql, + project_id=chart_adjustment_request.project_id, + ) + )["execute_sql"]["results"] + else: + sql_data = chart_adjustment_request.data + data_provided = True self._chart_adjustment_results[query_id] = ChartAdjustmentResultResponse( status="generating", @@ -150,10 +158,13 @@ async def chart_adjustment( chart_adjustment_result = await self._pipelines["chart_adjustment"].run( query=chart_adjustment_request.query, sql=chart_adjustment_request.sql, - adjustment_option=chart_adjustment_request.adjustment_option, + adjustment_command=chart_adjustment_request.adjustment_command, chart_schema=chart_adjustment_request.chart_schema, data=sql_data, + remove_data_from_chart_schema=chart_adjustment_request.remove_data_from_chart_schema, language=chart_adjustment_request.configurations.language, + data_provided=data_provided, + image_url=chart_adjustment_request.image_url, ) chart_result = chart_adjustment_result["post_process"]["results"]