From bc12f1cf666c7fca0ee29129eaab182964e14293 Mon Sep 17 00:00:00 2001 From: "siqi.an" Date: Tue, 16 Sep 2025 17:28:53 +0800 Subject: [PATCH] add qps&recall line --- .../components/check_results/stPageConfig.py | 2 +- .../frontend/components/qps_recall/charts.py | 117 ++++++++++++++++++ .../frontend/components/qps_recall/data.py | 58 +++++++++ vectordb_bench/frontend/pages/qps_recall.py | 59 +++++++++ 4 files changed, 235 insertions(+), 1 deletion(-) create mode 100644 vectordb_bench/frontend/components/qps_recall/charts.py create mode 100644 vectordb_bench/frontend/components/qps_recall/data.py create mode 100644 vectordb_bench/frontend/pages/qps_recall.py diff --git a/vectordb_bench/frontend/components/check_results/stPageConfig.py b/vectordb_bench/frontend/components/check_results/stPageConfig.py index 9521be285..8ddbf8b3d 100644 --- a/vectordb_bench/frontend/components/check_results/stPageConfig.py +++ b/vectordb_bench/frontend/components/check_results/stPageConfig.py @@ -5,7 +5,7 @@ def initResultsPageConfig(st): st.set_page_config( page_title=PAGE_TITLE, page_icon=FAVICON, - # layout="wide", + layout="wide", # initial_sidebar_state="collapsed", ) diff --git a/vectordb_bench/frontend/components/qps_recall/charts.py b/vectordb_bench/frontend/components/qps_recall/charts.py new file mode 100644 index 000000000..b2a180c5a --- /dev/null +++ b/vectordb_bench/frontend/components/qps_recall/charts.py @@ -0,0 +1,117 @@ +from vectordb_bench.frontend.components.check_results.expanderStyle import ( + initMainExpanderStyle, +) +from vectordb_bench.metric import metric_order, isLowerIsBetterMetric, metric_unit_map +from vectordb_bench.frontend.config.styles import * +import plotly.express as px +import pandas as pd +import plotly.graph_objects as go +import matplotlib.pyplot as plt + + +def drawCharts(st, allData, caseNames: list[str]): + initMainExpanderStyle(st) + for caseName in caseNames: + chartContainer = st.expander(caseName, True) + data = [data for data in allData if data["case_name"] == caseName] + drawChart(data, chartContainer, key_prefix=caseName) + + +def drawChart(data, st, key_prefix: str): + metricsSet = set() + for d in data: + metricsSet = metricsSet.union(d["metricsSet"]) + showlineMetrics = [metric for metric in metric_order[:2] if metric in metricsSet] + + if showlineMetrics: + metric = showlineMetrics[0] + key = f"{key_prefix}-{metric}" + drawlinechart(st, data, metric, key=key) + + +def drawBestperformance(data, y, group): + all_filter_points = [] + data = pd.DataFrame(data) + grouped = data.groupby(group) + for name, group_df in grouped: + filter_points = [] + current_start = 0 + for _ in range(len(group_df)): + if current_start >= len(group_df): + break + max_index = group_df[y].iloc[current_start:].idxmax() + filter_points.append(group_df.loc[max_index]) + + current_start = group_df.index.get_loc(max_index) + 1 + all_filter_points.extend(filter_points) + + all_filter_df = pd.DataFrame(all_filter_points) + remaining_df = data[~data.isin(all_filter_df).any(axis=1)] + new_data = all_filter_df.to_dict(orient="records") + remain_data = remaining_df.to_dict(orient="records") + return new_data, remain_data + + +def drawlinechart(st, data: list[object], metric, key: str): + unit = metric_unit_map.get(metric, "") + minV = min([d.get(metric, 0) for d in data]) + maxV = max([d.get(metric, 0) for d in data]) + padding = maxV - minV + rangeV = [ + minV - padding * 0.1, + maxV + padding * 0.1, + ] + x = "recall" + xrange = [0.8, 1.01] + y = "qps" + yrange = rangeV + data.sort(key=lambda a: a[x]) + group = "db_name" + new_data, new_remain_data = drawBestperformance(data, y, group) + unique_db_names = list(set(item["db_name"] for item in new_data + new_remain_data)) + + colors = plt.cm.get_cmap("tab10", len(unique_db_names)) + + color_map = { + db: f"rgb({int(colors(i)[0] * 255)}, {int(colors(i)[1] * 255)}, {int(colors(i)[2] * 255)})" + for i, db in enumerate(unique_db_names) + } + + fig = go.Figure() + + new_data_df = pd.DataFrame(new_data) + + for db in unique_db_names: + db_data = new_data_df[new_data_df["db_name"] == db] + fig.add_trace( + go.Scatter( + x=db_data["recall"], + y=db_data["qps"], + mode="lines+markers", + name=db, + line=dict(color=color_map[db]), + marker=dict(color=color_map[db]), + showlegend=True, + ) + ) + + for item in new_remain_data: + fig.add_trace( + go.Scatter( + x=[item["recall"]], + y=[item["qps"]], + mode="markers", + name=item["db_name"], + marker=dict(color=color_map[item["db_name"]]), + showlegend=False, + ) + ) + + fig.update_xaxes(range=xrange) + fig.update_yaxes(range=yrange) + fig.update_traces(textposition="bottom right", texttemplate="%{y:,.4~r}" + unit) + fig.update_layout( + margin=dict(l=0, r=0, t=40, b=0, pad=8), + legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""), + ) + st.plotly_chart(fig, use_container_width=True, key=key) diff --git a/vectordb_bench/frontend/components/qps_recall/data.py b/vectordb_bench/frontend/components/qps_recall/data.py new file mode 100644 index 000000000..b4cbcb1b5 --- /dev/null +++ b/vectordb_bench/frontend/components/qps_recall/data.py @@ -0,0 +1,58 @@ +from collections import defaultdict +from dataclasses import asdict +from vectordb_bench.backend.filter import FilterOp +from vectordb_bench.frontend.components.check_results.data import getFilterTasks +from vectordb_bench.frontend.components.check_results.filters import getShowDbsAndCases, getshownResults +from vectordb_bench.models import CaseResult, ResultLabel, TestResult + + +def getshownData(st, results: list[TestResult], filter_type: FilterOp = FilterOp.NonFilter, **kwargs): + # hide the nav + st.markdown( + "", + unsafe_allow_html=True, + ) + st.header("Filters") + shownResults = getshownResults(st, results, **kwargs) + showDBNames, showCaseNames = getShowDbsAndCases(st, shownResults, filter_type) + shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames) + return shownData, failedTasks, showCaseNames + + +def getChartData( + tasks: list[CaseResult], + dbNames: list[str], + caseNames: list[str], +): + filterTasks = getFilterTasks(tasks, dbNames, caseNames) + failedTasks = defaultdict(lambda: defaultdict(str)) + nonemergedTasks = [] + for task in filterTasks: + db_name = task.task_config.db_name + db = task.task_config.db.value + db_label = task.task_config.db_config.db_label or "" + version = task.task_config.db_config.version or "" + case = task.task_config.case_config.case + case_name = case.name + dataset_name = case.dataset.data.full_name + filter_rate = case.filter_rate + metrics = asdict(task.metrics) + label = task.label + if label == ResultLabel.NORMAL: + nonemergedTasks.append( + { + "db_name": db_name, + "db": db, + "db_label": db_label, + "dataset_name": dataset_name, + "filter_rate": filter_rate, + "version": version, + "case_name": case_name, + "metricsSet": set(metrics.keys()), + **metrics, + } + ) + else: + failedTasks[case_name][db_name] = label + + return nonemergedTasks, failedTasks diff --git a/vectordb_bench/frontend/pages/qps_recall.py b/vectordb_bench/frontend/pages/qps_recall.py new file mode 100644 index 000000000..62d4a74f6 --- /dev/null +++ b/vectordb_bench/frontend/pages/qps_recall.py @@ -0,0 +1,59 @@ +import streamlit as st +from vectordb_bench.frontend.components.check_results.footer import footer +from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon +from vectordb_bench.frontend.components.check_results.nav import ( + NavToQuriesPerDollar, + NavToRunTest, + NavToPages, +) +from vectordb_bench.frontend.components.qps_recall.charts import drawCharts +from vectordb_bench.frontend.components.qps_recall.data import getshownData +from vectordb_bench.frontend.components.get_results.saveAsImage import getResults + +from vectordb_bench.frontend.config.styles import FAVICON +from vectordb_bench.interface import benchmark_runner + + +def main(): + # set page config + st.set_page_config( + page_title="Label Filter", + page_icon=FAVICON, + layout="wide", + # initial_sidebar_state="collapsed", + ) + + # header + drawHeaderIcon(st) + + # navigate + NavToPages(st) + + allResults = benchmark_runner.get_results() + + st.title("Vector Database Benchmark (Qps & Recall)") + + # results selector and filter + resultSelectorContainer = st.sidebar.container() + shownData, failedTasks, showCaseNames = getshownData(resultSelectorContainer, allResults) + + resultSelectorContainer.divider() + + # nav + navContainer = st.sidebar.container() + NavToRunTest(navContainer) + NavToQuriesPerDollar(navContainer) + + # save or share + resultesContainer = st.sidebar.container() + getResults(resultesContainer, "vectordb_bench") + + # charts + drawCharts(st, shownData, showCaseNames) + + # footer + footer(st.container()) + + +if __name__ == "__main__": + main()