refactor(benchmarks): evaluation pipeline #71

Merged: 38 commits, Aug 6, 2024

Commits
553f111
rename benchmark to benchmarks
micpst Jul 3, 2024
b1b84a5
add README
micpst Jul 3, 2024
b19adc4
Merge branch 'main' into mp/refactor-benchmarks
micpst Jul 4, 2024
d499c42
mv everything to sql benchmark
micpst Jul 5, 2024
d5ce691
add docs for sql benchmarks
micpst Jul 5, 2024
940ce4e
update benchmark list
micpst Jul 5, 2024
5871032
fix imports
micpst Jul 5, 2024
4807cfc
rename bench script
micpst Jul 5, 2024
51330ed
add hf integration
micpst Jul 10, 2024
db2207c
add setup docs
micpst Jul 10, 2024
507b75c
eval for iql
micpst Jul 11, 2024
6344cf5
update README
micpst Jul 12, 2024
18022d5
brand new world
micpst Jul 12, 2024
c4242b6
great refactor
micpst Jul 15, 2024
abcbca5
update .gitignore and README
micpst Jul 15, 2024
bba8298
looks good
micpst Jul 15, 2024
7b29163
merge main
micpst Jul 16, 2024
5beac0e
update benchmarks
micpst Jul 16, 2024
460aa0f
fix + cleanup
micpst Jul 16, 2024
70d3e74
make EventTracker optional
micpst Jul 16, 2024
351fb0d
fix eval results saving
micpst Jul 16, 2024
6430576
add more metrics
micpst Jul 16, 2024
c91851f
almost done
micpst Jul 17, 2024
11545a1
update views
micpst Jul 19, 2024
c4005fb
update eval
micpst Jul 19, 2024
b8226c7
update pipeline for aggregations
micpst Jul 22, 2024
053eb61
fix pylint
micpst Jul 22, 2024
ed25531
final pipeline
micpst Jul 23, 2024
cd63a38
fix ci
micpst Jul 23, 2024
8c0f265
Merge branch 'main' into mp/refactor-benchmarks
micpst Jul 23, 2024
1c8fb18
add tests + update README
micpst Jul 24, 2024
1f0214a
add tests for ex
micpst Jul 24, 2024
ef0b0be
refactor
micpst Jul 30, 2024
c6781eb
add iql gen exception
micpst Jul 30, 2024
6ae80d7
move to separate file
micpst Jul 30, 2024
f34f791
Merge branch 'mp/iql-gen-exception' into mp/refactor-benchmarks
micpst Jul 30, 2024
d1b9857
update docs
micpst Jul 30, 2024
cb70ee9
add view names for eval results
micpst Jul 30, 2024
6 changes: 3 additions & 3 deletions .gitignore
@@ -55,6 +55,7 @@ licenses.txt
**/dist/
**/checkpoints/
**/outputs/
+**/multirun/

# Other env files
.python-version
@@ -74,7 +75,6 @@ coverage.xml

# dotenv
.env
-src/dbally_benchmark/.env

# coverage and pytest reports
coverage.xml
@@ -87,8 +87,8 @@ cmake-build-*/
**/.terraform.lock.hcl
**/.terraform

-# experiments results
-experiments/
+# benchmarks
+benchmarks/sql/data/

# mkdocs generated files
site/
22 changes: 0 additions & 22 deletions benchmark/dbally_benchmark/config.py

This file was deleted.

27 changes: 0 additions & 27 deletions benchmark/dbally_benchmark/constants.py

This file was deleted.

63 changes: 0 additions & 63 deletions benchmark/dbally_benchmark/dataset/bird_dataset.py

This file was deleted.

157 changes: 0 additions & 157 deletions benchmark/dbally_benchmark/e2e_benchmark.py

This file was deleted.

41 changes: 0 additions & 41 deletions benchmark/dbally_benchmark/evaluate.py

This file was deleted.

10 changes: 0 additions & 10 deletions benchmark/dbally_benchmark/experiment_config/config.yaml

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

13 changes: 0 additions & 13 deletions benchmark/dbally_benchmark/iql/iql_result.py

This file was deleted.

68 changes: 0 additions & 68 deletions benchmark/dbally_benchmark/iql/method_call_visitor.py

This file was deleted.

133 changes: 0 additions & 133 deletions benchmark/dbally_benchmark/iql/metrics.py

This file was deleted.

165 changes: 0 additions & 165 deletions benchmark/dbally_benchmark/iql_benchmark.py

This file was deleted.

10 changes: 0 additions & 10 deletions benchmark/dbally_benchmark/paths.py

This file was deleted.

268 changes: 0 additions & 268 deletions benchmark/dbally_benchmark/text2sql/metrics.py

This file was deleted.

19 changes: 0 additions & 19 deletions benchmark/dbally_benchmark/text2sql/prompt_template.py

This file was deleted.

10 changes: 0 additions & 10 deletions benchmark/dbally_benchmark/text2sql/text2sql_result.py

This file was deleted.

151 changes: 0 additions & 151 deletions benchmark/dbally_benchmark/text2sql_benchmark.py

This file was deleted.

73 changes: 0 additions & 73 deletions benchmark/dbally_benchmark/utils.py

This file was deleted.

227 changes: 0 additions & 227 deletions benchmark/dbally_benchmark/views/superhero.py

This file was deleted.

66 changes: 0 additions & 66 deletions benchmark/tests/unit/test_iql_metrics.py

This file was deleted.

26 changes: 0 additions & 26 deletions benchmark/tests/unit/test_main_evaluate.py

This file was deleted.

17 changes: 0 additions & 17 deletions benchmark/tests/unit/test_method_call_visitor.py

This file was deleted.

17 changes: 17 additions & 0 deletions benchmarks/README.md
@@ -0,0 +1,17 @@
# db-ally benchmarks

This folder contains scripts that produce reproducible timings and evaluation metrics of various db-ally features.

## Setup environment

Before installing any package, make sure you have Python 3.8 or higher installed on your machine. From the root directory of the project, install the dependencies:

```bash
pip install -e '.[benchmarks]'
```

## Benchmark list

Each subfolder contains a separate benchmark suite. Links are provided where descriptions exist:

- [`SQL`](sql/README.md)
71 changes: 71 additions & 0 deletions benchmarks/sql/README.md
@@ -0,0 +1,71 @@
# SQL benchmarks

This folder contains benchmarks for querying SQL databases with db-ally. This suite evaluates the following components:

- `COLLECTION` - measures correctness of SQL queries generated by the collection.
- `IQL_VIEW` - measures correctness of SQL queries generated by the structured views.
- `SQL_VIEW` - measures correctness of SQL queries generated by the freeform views.

All benchmarks are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.

## Run benchmarks

### Usage

Before starting, download the `superhero.sqlite` database file from [BIRD](https://bird-bench.github.io/), rename it to `superhero.db`, and place it in the `data/` folder.
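
For example, assuming the BIRD dev databases archive has already been downloaded and unpacked (the source path below is illustrative and may differ on your machine), the file can be put in place with a short script:

```python
# Hedged sketch of the data preparation step; adjust the source path to your setup.
import shutil
from pathlib import Path

Path("data").mkdir(exist_ok=True)
shutil.copy("dev_databases/superhero/superhero.sqlite", "data/superhero.db")
```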

Run the whole suite on the `superhero` database with `gpt-3.5-turbo`:

```bash
python bench.py --multirun setup=iql-view,sql-view,collection
```

Run on multiple databases:

```bash
python bench.py setup=sql-view setup/views/freeform@setup.views='[superhero,...]' data=bird
```

You can also run each evaluation separately or in subgroups:

```bash
python bench.py setup=iql-view
python bench.py --multirun setup=iql-view,sql-view
```

Compare IQL/SQL generation performance on multiple LLMs:

```bash
python bench.py --multirun setup=iql-view setup/llm=gpt-3.5-turbo,claude-3.5-sonnet
python bench.py --multirun setup=sql-view setup/llm=gpt-3.5-turbo,claude-3.5-sonnet
```

For the `collection` setup, you need to specify models for both the view selection and the IQL generation step:

```bash
python bench.py --multirun \
setup=collection \
setup/llm@setup.selector_llm=gpt-3.5-turbo,claude-3.5-sonnet \
setup/llm@setup.generator_llm=gpt-3.5-turbo,claude-3.5-sonnet
```

### Log to Neptune

Before running the suite with Neptune, configure the following environment variables:

```bash
export NEPTUNE_API_TOKEN="API_TOKEN"
export NEPTUNE_PROJECT="WORKSPACE_NAME/PROJECT_NAME"
```

Export evaluation results to Neptune:

```bash
python bench.py setup=iql-view neptune=True
```

## Run tests

```bash
python -m pytest
```
151 changes: 151 additions & 0 deletions benchmarks/sql/bench.py
@@ -0,0 +1,151 @@
import asyncio
import logging
from enum import Enum
from pathlib import Path

import hydra
import neptune
from bench.evaluator import Evaluator
from bench.loaders import CollectionDataLoader, IQLViewDataLoader, SQLViewDataLoader
from bench.metrics import (
ExecutionAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
IQLFiltersPrecision,
IQLFiltersRecall,
MetricSet,
SQLExactMatch,
ViewSelectionAccuracy,
ViewSelectionPrecision,
ViewSelectionRecall,
)
from bench.pipelines import CollectionEvaluationPipeline, IQLViewEvaluationPipeline, SQLViewEvaluationPipeline
from bench.utils import save
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig

logging.getLogger("LiteLLM").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
log = logging.getLogger(__name__)


class EvaluationType(Enum):
"""
Enum representing the evaluation type.
"""

IQL = "IQL_VIEW"
SQL = "SQL_VIEW"
E2E = "COLLECTION"


EVALUATION_DATALOADERS = {
EvaluationType.IQL.value: IQLViewDataLoader,
EvaluationType.SQL.value: SQLViewDataLoader,
EvaluationType.E2E.value: CollectionDataLoader,
}

EVALUATION_PIPELINES = {
EvaluationType.IQL.value: IQLViewEvaluationPipeline,
EvaluationType.SQL.value: SQLViewEvaluationPipeline,
EvaluationType.E2E.value: CollectionEvaluationPipeline,
}

EVALUATION_METRICS = {
EvaluationType.IQL.value: MetricSet(
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
IQLFiltersParseability,
IQLFiltersCorrectness,
ExecutionAccuracy,
),
EvaluationType.SQL.value: MetricSet(
SQLExactMatch,
ExecutionAccuracy,
),
EvaluationType.E2E.value: MetricSet(
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
IQLFiltersParseability,
IQLFiltersCorrectness,
ViewSelectionAccuracy,
ViewSelectionPrecision,
ViewSelectionRecall,
SQLExactMatch,
ExecutionAccuracy,
),
}


async def bench(config: DictConfig) -> None:
"""
Function running evaluation for all datasets and evaluation tasks defined in hydra config.
Args:
config: Hydra configuration.
"""
log.info("Starting evaluation: %s", config.setup.name)

dataloader = EVALUATION_DATALOADERS[config.setup.name](config)
pipeline = EVALUATION_PIPELINES[config.setup.name](config)
metrics = EVALUATION_METRICS[config.setup.name](config)

evaluator = Evaluator(config.setup.name)
results = await evaluator.compute(
pipeline=pipeline,
dataloader=dataloader,
metrics=metrics,
)

log.info("Evaluation finished. Saving results...")

output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
metrics_file = output_dir / "metrics.json"
results_file = output_dir / "results.json"

save(metrics_file, metrics=results["metrics"], time_perf=results["time_perf"])
save(results_file, results=results["results"])

log.info("Evaluation results saved under directory: %s", output_dir)

if config.neptune:
run = neptune.init_run()
run["sys/tags"].add(
[
config.setup.name,
*config.data.db_ids,
*config.data.difficulties,
]
)
run["config"] = stringify_unsupported(config)
run["evaluation/metrics"] = stringify_unsupported(results["metrics"])
run["evaluation/time_perf"] = stringify_unsupported(results["time_perf"])
run["evaluation/metrics.json"].upload(metrics_file.as_posix())
run["evaluation/results.json"].upload(results_file.as_posix())


@hydra.main(config_path="config", config_name="config", version_base="3.2")
def main(config: DictConfig) -> None:
"""
Function running evaluation for all datasets and evaluation tasks defined in hydra config.
Args:
config: Hydra configuration.
"""
asyncio.run(bench(config))


if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
File renamed without changes.
122 changes: 122 additions & 0 deletions benchmarks/sql/bench/evaluator.py
@@ -0,0 +1,122 @@
import time
from dataclasses import asdict
from typing import Any, Callable, Dict, List, Tuple

from datasets import Dataset
from tqdm.asyncio import tqdm

from .loaders import DataLoader
from .metrics.base import MetricSet
from .pipelines import EvaluationPipeline, EvaluationResult


class Evaluator:
"""
Evaluator class.
"""

def __init__(self, task: str) -> None:
"""
Constructs the evaluator.
Args:
task: The task for the evaluator.
"""
self.task = task

async def compute(
self,
pipeline: Callable,
dataloader: DataLoader,
metrics: MetricSet,
) -> Dict[str, Any]:
"""
Compute the evaluation results for the given pipeline and data.
Args:
pipeline: The pipeline to be evaluated.
dataloader: The dataloader to load the data.
metrics: The metrics to be computed.
Returns:
The evaluation results.
"""
dataset = await dataloader.load()
results, perf_results = await self._call_pipeline(pipeline, dataset)
computed_metrics = self._compute_metrics(metrics, results)
results = self._results_processor(results)

result = {}
result.update(perf_results)
result.update(computed_metrics)
result.update(results)
return result

async def _call_pipeline(
self,
pipe: EvaluationPipeline,
dataset: Dataset,
) -> Tuple[List[EvaluationResult], Dict[str, Any]]:
"""
Call the pipeline with the given data.
Args:
pipe: The pipeline to be called.
dataset: The evaluation dataset.
Returns:
The evaluation results and performance metrics.
"""
start_time = time.perf_counter()
pipe_outputs = await tqdm.gather(*[pipe(data) for data in dataset], desc="Evaluation")
end_time = time.perf_counter()
return pipe_outputs, self._compute_time_perf(start_time, end_time, len(pipe_outputs))

def _results_processor(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Process the results.
Args:
results: The evaluation results.
Returns:
The processed results.
"""
return {"results": [asdict(result) for result in results]}

def _compute_metrics(self, metrics: MetricSet, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Compute a metric using the given inputs.
Args:
metrics: The metrics to be computed.
results: The evaluation results.
Returns:
The computed metric.
"""
return {"metrics": metrics.compute(results)}

def _compute_time_perf(self, start_time: float, end_time: float, num_samples: int) -> Dict[str, Any]:
"""
Compute the performance metrics.
Args:
start_time: The start time.
end_time: The end time.
num_samples: The number of samples.
Returns:
The performance metrics.
"""
latency = end_time - start_time
throughput = num_samples / latency
latency_sample = 1.0 / throughput

return {
"time_perf": {
"total_time_in_seconds": latency,
"samples_per_second": throughput,
"latency_in_seconds": latency_sample,
},
}
96 changes: 96 additions & 0 deletions benchmarks/sql/bench/loaders.py
@@ -0,0 +1,96 @@
from abc import ABC, abstractmethod
from typing import Dict, Iterable

from datasets import Dataset, load_dataset


class DataLoader(ABC):
"""
Data loader.
"""

def __init__(self, config: Dict) -> None:
self.config = config

@abstractmethod
async def load(self) -> Iterable:
"""
Load the data.
Returns:
The loaded data.
"""


class HuggingFaceDataLoader(DataLoader):
"""
Hugging Face data loader.
"""

async def load(self) -> Dataset:
"""
Load the data from Hugging Face.
Returns:
The loaded data.
"""
return load_dataset(
path=self.config.data.path,
split=self.config.data.split,
)


class IQLViewDataLoader(HuggingFaceDataLoader):
"""
Data loader for IQL view evaluation.
"""

async def load(self) -> Dataset:
"""
Load the data from Hugging Face and filter out samples without views.
Returns:
The loaded data.
"""
dataset = await super().load()
return dataset.filter(
lambda x: x["db_id"] in self.config.data.db_ids
and x["difficulty"] in self.config.data.difficulties
and x["view_name"] is not None
)


class SQLViewDataLoader(HuggingFaceDataLoader):
"""
Data loader for SQL view evaluation.
"""

async def load(self) -> Dataset:
"""
Load the data from Hugging Face.
Returns:
The loaded data.
"""
dataset = await super().load()
return dataset.filter(
lambda x: x["db_id"] in self.config.data.db_ids and x["difficulty"] in self.config.data.difficulties
)


class CollectionDataLoader(HuggingFaceDataLoader):
"""
Data loader for collection evaluation.
"""

async def load(self) -> Dataset:
"""
Load the data from Hugging Face.
Returns:
The loaded data.
"""
dataset = await super().load()
return dataset.filter(
lambda x: x["db_id"] in self.config.data.db_ids and x["difficulty"] in self.config.data.difficulties
)
31 changes: 31 additions & 0 deletions benchmarks/sql/bench/metrics/__init__.py
@@ -0,0 +1,31 @@
from .base import Metric, MetricSet
from .iql import (
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
IQLFiltersPrecision,
IQLFiltersRecall,
)
from .selector import ViewSelectionAccuracy, ViewSelectionPrecision, ViewSelectionRecall
from .sql import ExecutionAccuracy, SQLExactMatch

__all__ = [
"Metric",
"MetricSet",
"FilteringAccuracy",
"FilteringPrecision",
"FilteringRecall",
"IQLFiltersAccuracy",
"IQLFiltersPrecision",
"IQLFiltersRecall",
"IQLFiltersParseability",
"IQLFiltersCorrectness",
"SQLExactMatch",
"ViewSelectionAccuracy",
"ViewSelectionPrecision",
"ViewSelectionRecall",
"ExecutionAccuracy",
]
74 changes: 74 additions & 0 deletions benchmarks/sql/bench/metrics/base.py
@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type

from typing_extensions import Self

from ..pipelines import EvaluationResult


class Metric(ABC):
"""
Base class for metrics.
"""

def __init__(self, config: Optional[Dict] = None) -> None:
"""
Initializes the metric.
Args:
config: The metric configuration.
"""
self.config = config or {}

@abstractmethod
def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Compute the metric.
Args:
results: The evaluation results.
Returns:
The computed metric.
"""


class MetricSet:
"""
Represents a set of metrics.
"""

def __init__(self, *metrics: List[Type[Metric]]) -> None:
"""
Initializes the metric set.
Args:
metrics: The metrics.
"""
self._metrics = metrics
self.metrics: List[Metric] = []

def __call__(self, config: Dict) -> Self:
"""
Initializes the metrics.
Args:
config: The configuration for the metrics.
Returns:
The initialized metric set.
"""
self.metrics = [metric(config) for metric in self._metrics]
return self

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Compute the metrics.
Args:
results: The evaluation results.
Returns:
The computed metrics.
"""
return {name: value for metric in self.metrics for name, value in metric.compute(results).items()}
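
A minimal usage sketch of the `Metric`/`MetricSet` API defined above. The `ExampleAccuracy` metric and the empty config are illustrative assumptions, not part of the benchmark suite:

```python
from typing import Any, Dict, List

from bench.metrics import Metric, MetricSet
from bench.pipelines import EvaluationResult


class ExampleAccuracy(Metric):
    """Toy metric: fraction of predictions whose SQL matches the reference SQL."""

    def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
        return {
            "EXAMPLE/ACC": (
                sum(r.prediction.sql == r.reference.sql for r in results) / len(results)
                if results
                else None
            )
        }


metric_set = MetricSet(ExampleAccuracy)  # collect metric classes
metric_set = metric_set(config={})       # instantiate them with a shared config
print(metric_set.compute(results=[]))    # {'EXAMPLE/ACC': None}
```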
286 changes: 286 additions & 0 deletions benchmarks/sql/bench/metrics/iql.py
@@ -0,0 +1,286 @@
from typing import Any, Dict, List

from ..pipelines import EvaluationResult
from .base import Metric


class FilteringAccuracy(Metric):
"""
Filtering accuracy is the proportion of correct decisions (to filter or not) out of all decisions made.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the filtering accuracy.
Args:
results: List of evaluation results.
Returns:
Filtering accuracy.
"""
results = [result for result in results if result.reference.iql and result.prediction.iql]
return {
"DM/FLT/ACC": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported
for result in results
)
/ len(results)
if results
else None
)
}


class FilteringPrecision(Metric):
"""
Filtering precision is the proportion of correct decisions to filter out of all decisions to filter.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the filtering precision.
Args:
results: List of evaluation results.
Returns:
Filtering precision.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (result.prediction.iql.filters.source or result.prediction.iql.filters.unsupported)
]
return {
"DM/FLT/PRECISION": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported
for result in results
)
/ len(results)
if results
else None
)
}


class FilteringRecall(Metric):
"""
Filtering recall is the proportion of correct decisions to filter out of all cases where filtering
should have been applied.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the filtering recall.
Args:
results: List of evaluation results.
Returns:
Filtering recall.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (result.reference.iql.filters.source or result.reference.iql.filters.unsupported)
]
return {
"DM/FLT/RECALL": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported
for result in results
)
/ len(results)
if results
else None
)
}


class IQLFiltersAccuracy(Metric):
"""
IQL filters accuracy is the proportion of correct IQL generations and unsupported query identifications out
of all attempts.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters accuracy.
Args:
results: List of evaluation results.
Returns:
IQL filters accuracy.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (
result.reference.iql.filters.source
or result.reference.iql.filters.unsupported
and result.prediction.iql.filters.source
or result.prediction.iql.filters.unsupported
)
]
return {
"IQL/FLT/ACC": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
for result in results
)
/ len(results)
if results
else None
)
}


class IQLFiltersPrecision(Metric):
"""
IQL filters precision is the proportion of correct IQL generations out of all IQL generation attempts.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters precision.
Args:
results: List of evaluation results.
Returns:
IQL filters precision.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (
result.reference.iql.filters.source
or result.reference.iql.filters.unsupported
and result.prediction.iql.filters.source
)
]
return {
"IQL/FLT/PRECISION": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
for result in results
)
/ len(results)
if results
else None
)
}


class IQLFiltersRecall(Metric):
"""
IQL filters recall is the proportion of correct IQL generations out of all cases where an IQL
should have been generated.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters recall.
Args:
results: List of evaluation results.
Returns:
IQL filters recall.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (
result.reference.iql.filters.source
and result.prediction.iql.filters.source
or result.prediction.iql.filters.unsupported
)
]
return {
"IQL/FLT/RECALL": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
for result in results
)
/ len(results)
if results
else None
)
}


class IQLFiltersParseability(Metric):
"""
IQL filters parseability is the proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters parseability.
Args:
results: List of evaluation results.
Returns:
IQL filters parseability.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (result.reference.iql.filters and result.prediction.iql.filters)
and (result.reference.iql.filters.source and result.prediction.iql.filters.source)
]
return {
"IQL/FLT/PARSEABILITY": (
sum(result.prediction.iql.filters.valid for result in results) / len(results) if results else None
)
}


class IQLFiltersCorrectness(Metric):
"""
IQL filters correctness is the proportion of IQLs that produce correct results out of all parseable IQLs.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters correctness.
Args:
results: List of evaluation results.
Returns:
IQL filters correctness.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (
result.reference.iql.filters.source
and result.prediction.iql.filters.source
and result.prediction.iql.filters.valid
)
]
return {
"IQL/FLT/CORRECTNESS": (
sum(result.prediction.iql.filters.source == result.reference.iql.filters.source for result in results)
/ len(results)
if results
else None
)
}
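
A hedged sketch of feeding hand-built results into one of the metrics above; the question and the IQL filter call are made up for illustration:

```python
from bench.metrics import FilteringAccuracy
from bench.pipelines.base import IQL, EvaluationResult, ExecutionResult, IQLResult

results = [
    EvaluationResult(
        db_id="superhero",
        question="Who is the tallest hero?",  # illustrative question
        reference=ExecutionResult(
            iql=IQLResult(filters=IQL(source="filter_by_height(200)"), aggregation=IQL()),
        ),
        prediction=ExecutionResult(
            iql=IQLResult(filters=IQL(source="filter_by_height(200)"), aggregation=IQL()),
        ),
    ),
]

print(FilteringAccuracy().compute(results))  # {'DM/FLT/ACC': 1.0}
```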
85 changes: 85 additions & 0 deletions benchmarks/sql/bench/metrics/selector.py
@@ -0,0 +1,85 @@
from typing import Any, Dict, List

from ..pipelines import EvaluationResult
from .base import Metric


class ViewSelectionAccuracy(Metric):
"""
View selection accuracy is the proportion of correct view selections out of all view selection attempts.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the view selection accuracy.
Args:
results: List of evaluation results.
Returns:
View selection accuracy.
"""
return {
"VIEW/ACC": (
sum(result.reference.view_name == result.prediction.view_name for result in results) / len(results)
if results
else None
)
}


class ViewSelectionPrecision(Metric):
"""
View selection precision is the proportion of correct view selections out of all cases where a view was selected.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the view selection precision.
Args:
results: List of evaluation results.
Returns:
View selection precision.
"""
results = [result for result in results if result.prediction.view_name]
return {
"VIEW/PRECISION": (
sum(result.prediction.view_name == result.reference.view_name for result in results) / len(results)
if results
else None
)
}


class ViewSelectionRecall(Metric):
"""
View selection recall is the proportion of correct view selections out of all cases where a view should have
been selected.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the view selection recall.
Args:
results: List of evaluation results.
Returns:
View selection recall.
"""
results = [
result
for result in results
if result.prediction.view_name is None
and result.reference.view_name
or result.prediction.view_name == result.reference.view_name
]
return {
"VIEW/RECALL": (
sum(result.prediction.view_name == result.reference.view_name for result in results) / len(results)
if results
else None
)
}
151 changes: 151 additions & 0 deletions benchmarks/sql/bench/metrics/sql.py
@@ -0,0 +1,151 @@
import time
from typing import Any, Dict, List

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError

from ..pipelines import EvaluationResult
from .base import Metric


class SQLExactMatch(Metric):
"""
Exact match ratio, i.e. the proportion of examples in the evaluation set for which
the predicted SQL is identical to the ground truth SQL.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the exact match ratio.
Args:
results: List of evaluation results.
Returns:
The exact match ratio.
"""
return {
"SQL/EM": (
sum(result.prediction.sql == result.reference.sql for result in results) / len(results)
if results
else 0.0
)
}


class _DBMixin:
"""
Mixin class for database operations.
"""

def __init__(self, config: Dict, *args: Any, **kwargs: Any) -> None:
super().__init__(config, *args, **kwargs)
self.dbs = {db: create_engine(f"sqlite:///data/{db}.db") for db in config.data.db_ids}

def _execute_query(self, query: str, db_id: str) -> List[Dict[str, Any]]:
"""
Execute the given query on the database.
Args:
query: The query to be executed.
db_id: The database identifier.
Returns:
The query results.
"""
with self.dbs[db_id].connect() as connection:
rows = connection.execute(text(query)).fetchall()
return [dict(row._mapping) for row in rows] # pylint: disable=protected-access

def _avarage_execution_time(self, query: str, db_id: str, n: int = 100) -> float:
"""
Execute the given query on the database n times and return the average execution time.
Args:
query: The query to be executed.
db_id: The database identifier.
n: The number of times to execute the query.
Returns:
The average execution time.
"""
total_time = 0
for _ in range(n):
start_time = time.perf_counter()
self._execute_query(query, db_id)
total_time += time.perf_counter() - start_time
return total_time / n


class ExecutionAccuracy(_DBMixin, Metric):
"""
Execution accuracy score, i.e. the proportion of examples in the evaluation set for
which the executed results of both the predicted and ground-truth SQLs are identical.
The valid efficiency score measures the efficiency of valid SQLs generated
by models. More details about these metrics can be found here: https://arxiv.org/pdf/2305.03111.pdf.
"""

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Calculates the execution accuracy score and valid efficiency score.
Args:
results: List of evaluation results.
Returns:
Execution accuracy score and valid efficiency score.
"""
accurate_results = [result for result in results if self._execution_accuracy(result)]
return {
"EX": len(accurate_results) / len(results) if results else None,
"VES": sum(
(
self._avarage_execution_time(result.reference.sql, result.db_id)
/ self._avarage_execution_time(result.prediction.sql, result.db_id)
)
** 0.5
for result in accurate_results
)
/ len(results)
if results
else None,
}

def _execution_accuracy(self, result: EvaluationResult) -> bool:
"""
Checks if the execution results of both the predicted and ground-truth SQLs are identical.
Args:
result: Evaluation result.
Returns:
True if the execution results are identical, False otherwise.
"""
if result.prediction.sql is None:
return False

try:
ref_results = self._execute_query(result.reference.sql, result.db_id)
pred_results = self._execute_query(result.prediction.sql, result.db_id)
except SQLAlchemyError:
return False

reference = pd.DataFrame(ref_results)
prediction = pd.DataFrame(pred_results)

# If filtering works correctly, the number of rows will be the same
# TODO: Sometimes a different number of rows is okay, e.g. if df has aggregated values that are expanded in gt
if reference.shape[0] != prediction.shape[0]:
return False

# Returned view may have the same columns, or more columns than the ground truth
if not reference.columns.isin(prediction.columns).all():
return False

# Check dataframe equality, disregarding index and row order
# The commented-out approach below is also correct but slower; kept for reference
# return df_gt.merge(df[df_gt.columns], how='outer', on=df_gt.columns.tolist(),
# indicator='indicator').indicator.drop_duplicates().values.tolist() == ['both']
prediction = prediction[reference.columns].sort_values(by=reference.columns.tolist()).reset_index(drop=True)
reference = reference.sort_values(by=reference.columns.tolist()).reset_index(drop=True)
return prediction.equals(reference)
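
A hedged worked example of the `EX` and `VES` scores that `ExecutionAccuracy.compute` returns above, using made-up execution times:

```python
# Made-up timings (seconds) for three evaluated queries; None marks a failed prediction.
ref_times = [0.020, 0.050, 0.010]   # average runtime of the ground-truth SQL
pred_times = [0.010, 0.100, None]   # average runtime of the predicted SQL
accurate = [True, True, False]      # whether rows/columns matched the ground truth

n = len(ref_times)
ex = sum(accurate) / n  # EX = 2/3
ves = sum(
    (ref / pred) ** 0.5
    for ref, pred, ok in zip(ref_times, pred_times, accurate)
    if ok
) / n  # VES = (sqrt(2.0) + sqrt(0.5)) / 3
print({"EX": round(ex, 3), "VES": round(ves, 3)})  # {'EX': 0.667, 'VES': 0.707}
```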
13 changes: 13 additions & 0 deletions benchmarks/sql/bench/pipelines/__init__.py
@@ -0,0 +1,13 @@
from .base import EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult
from .collection import CollectionEvaluationPipeline
from .view import IQLViewEvaluationPipeline, SQLViewEvaluationPipeline

__all__ = [
"CollectionEvaluationPipeline",
"EvaluationPipeline",
"EvaluationResult",
"ExecutionResult",
"IQLResult",
"IQLViewEvaluationPipeline",
"SQLViewEvaluationPipeline",
]
84 changes: 84 additions & 0 deletions benchmarks/sql/bench/pipelines/base.py
@@ -0,0 +1,84 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional

from dbally.llms.base import LLM
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM


@dataclass
class IQL:
"""
Represents the IQL.
"""

source: Optional[str] = None
unsupported: bool = False
valid: bool = True


@dataclass
class IQLResult:
"""
Represents the result of an IQL query execution.
"""

filters: IQL
aggregation: IQL
context: bool = False


@dataclass
class ExecutionResult:
"""
Represents the result of a single query execution.
"""

view_name: Optional[str] = None
iql: Optional[IQLResult] = None
sql: Optional[str] = None


@dataclass
class EvaluationResult:
"""
Represents the result of a single evaluation.
"""

db_id: str
question: str
reference: ExecutionResult
prediction: ExecutionResult


class EvaluationPipeline(ABC):
"""
Base class for evaluation pipelines.
"""

def get_llm(self, config: Dict) -> LLM:
"""
Returns the LLM based on the configuration.
Args:
config: The LLM configuration.
Returns:
The LLM object.
"""
if config.model_name.startswith("local/"):
return LocalLLM(config.model_name.split("/", 1)[1])
return LiteLLM(config.model_name)

@abstractmethod
async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
"""
Runs the evaluation pipeline.
Args:
data: The evaluation data.
Returns:
The evaluation result.
"""
140 changes: 140 additions & 0 deletions benchmarks/sql/bench/pipelines/collection.py
@@ -0,0 +1,140 @@
from typing import Any, Dict

from sqlalchemy import create_engine

import dbally
from dbally.collection.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.view_selection.llm_view_selector import LLMViewSelector
from dbally.views.exceptions import IQLGenerationError

from ..views import VIEWS_REGISTRY
from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult


class CollectionEvaluationPipeline(EvaluationPipeline):
"""
Collection evaluation pipeline.
"""

def __init__(self, config: Dict) -> None:
"""
Constructs the pipeline for evaluating collection predictions.
Args:
config: The configuration for the pipeline.
"""
self.collection = self.get_collection(config.setup)

def get_collection(self, config: Dict) -> Collection:
"""
Sets up the collection based on the configuration.
Args:
config: The collection configuration.
Returns:
The collection.
"""
generator_llm = self.get_llm(config.generator_llm)
selector_llm = self.get_llm(config.selector_llm)
view_selector = LLMViewSelector(selector_llm)

collection = dbally.create_collection(
name=config.name,
llm=generator_llm,
view_selector=view_selector,
)
collection.n_retries = 0

for db_name, view_names in config.views.items():
db = create_engine(f"sqlite:///data/{db_name}.db")
for view_name in view_names:
view_cls = VIEWS_REGISTRY[view_name]
collection.add(view_cls, lambda: view_cls(db)) # pylint: disable=cell-var-from-loop

return collection

async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
"""
Runs the collection evaluation pipeline.
Args:
data: The evaluation data.
Returns:
The evaluation result.
"""
try:
result = await self.collection.ask(
question=data["question"],
dry_run=True,
return_natural_response=False,
)
except NoViewFoundError:
prediction = ExecutionResult(
view_name=None,
iql=None,
sql=None,
)
except IQLGenerationError as exc:
prediction = ExecutionResult(
view_name=exc.view_name,
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=result.view_name,
iql=IQLResult(
filters=IQL(
source=result.context.get("iql"),
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
),
sql=result.context.get("sql"),
)

reference = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=data["iql_filters"],
unsupported=data["iql_filters_unsupported"],
valid=True,
),
aggregation=IQL(
source=data["iql_aggregation"],
unsupported=data["iql_aggregation_unsupported"],
valid=True,
),
context=data["iql_context"],
),
sql=data["sql"],
)

return EvaluationResult(
db_id=data["db_id"],
question=data["question"],
reference=reference,
prediction=prediction,
)
215 changes: 215 additions & 0 deletions benchmarks/sql/bench/pipelines/view.py
@@ -0,0 +1,215 @@
# pylint: disable=duplicate-code

from abc import ABC, abstractmethod
from typing import Any, Dict, Type

from sqlalchemy import create_engine

from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.views.exceptions import IQLGenerationError
from dbally.views.freeform.text2sql.view import BaseText2SQLView
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

from ..views import VIEWS_REGISTRY
from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult


class ViewEvaluationPipeline(EvaluationPipeline, ABC):
"""
View evaluation pipeline.
"""

def __init__(self, config: Dict) -> None:
"""
Constructs the pipeline for evaluating IQL predictions.
Args:
config: The configuration for the pipeline.
"""
self.llm = self.get_llm(config.setup.llm)
self.dbs = self.get_dbs(config.setup)
self.views = self.get_views(config.setup)

def get_dbs(self, config: Dict) -> Dict:
"""
Returns a mapping of database names to SQLAlchemy engines based on the configuration.
Args:
config: The setup configuration.
Returns:
The database engines.
"""
return {db: create_engine(f"sqlite:///data/{db}.db") for db in config.views}

@abstractmethod
def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]:
"""
Creates the view classes mapping based on the configuration.
Args:
config: The views configuration.
Returns:
The view classes mapping.
"""


class IQLViewEvaluationPipeline(ViewEvaluationPipeline):
"""
IQL view evaluation pipeline.
"""

def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]:
"""
Creates the view classes mapping based on the configuration.
Args:
config: The views configuration.
Returns:
The view classes mapping.
"""
return {
view_name: VIEWS_REGISTRY[view_name] for view_names in config.views.values() for view_name in view_names
}

async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
"""
Runs the evaluation pipeline.
Args:
data: The evaluation data.
Returns:
The evaluation result.
"""
view = self.views[data["view_name"]](self.dbs[data["db_id"]])

try:
result = await view.ask(
query=data["question"],
llm=self.llm,
dry_run=True,
n_retries=0,
)
except IQLGenerationError as exc:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=result.context["iql"],
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
),
sql=result.context["sql"],
)

reference = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=data["iql_filters"],
unsupported=data["iql_filters_unsupported"],
valid=True,
),
aggregation=IQL(
source=data["iql_aggregation"],
unsupported=data["iql_aggregation_unsupported"],
valid=True,
),
context=data["iql_context"],
),
sql=data["sql"],
)

return EvaluationResult(
db_id=data["db_id"],
question=data["question"],
reference=reference,
prediction=prediction,
)


class SQLViewEvaluationPipeline(ViewEvaluationPipeline):
"""
SQL view evaluation pipeline.
"""

def get_views(self, config: Dict) -> Dict[str, Type[BaseText2SQLView]]:
"""
Creates the view classes mapping based on the configuration.
Args:
config: The views configuration.
Returns:
The view classes mapping.
"""
return {db_id: VIEWS_REGISTRY[view_name] for db_id, view_name in config.views.items()}

async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
"""
Runs the evaluation pipeline.
Args:
data: The evaluation data.
Returns:
The evaluation result.
"""
view = self.views[data["db_id"]](self.dbs[data["db_id"]])

try:
result = await view.ask(
query=data["question"],
llm=self.llm,
dry_run=True,
n_retries=0,
)
# TODO: Remove this broad exception handling once the Text2SQL view is fixed
except Exception: # pylint: disable=broad-except
prediction = ExecutionResult(
view_name=view.__class__.__name__,
)
else:
prediction = ExecutionResult(
view_name=view.__class__.__name__,
sql=result.context["sql"],
)

reference = ExecutionResult(
view_name=data["view_name"],
sql=data["sql"],
)

return EvaluationResult(
db_id=data["db_id"],
question=data["question"],
reference=reference,
prediction=prediction,
)
23 changes: 23 additions & 0 deletions benchmarks/sql/bench/utils.py
@@ -0,0 +1,23 @@
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any


def save(file_path: Path, **data: Any) -> None:
"""
Save the data to a file. Add the current timestamp and Python version to the data.
Args:
file_path: The path to the file.
data: The data to be saved.
"""
current_time = datetime.now()

data["_timestamp"] = current_time.isoformat()
data["_python_version"] = sys.version
data["_interpreter_path"] = sys.executable

with open(file_path, "w", encoding="utf-8") as file:
json.dump(data, file, indent=4)
12 changes: 12 additions & 0 deletions benchmarks/sql/bench/views/__init__.py
@@ -0,0 +1,12 @@
from typing import Dict, Type

from dbally.views.base import BaseView

from .freeform.superhero import SuperheroFreeformView
from .structured.superhero import PublisherView, SuperheroView

VIEWS_REGISTRY: Dict[str, Type[BaseView]] = {
PublisherView.__name__: PublisherView,
SuperheroView.__name__: SuperheroView,
SuperheroFreeformView.__name__: SuperheroFreeformView,
}
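
A hedged sketch of resolving a view class from `VIEWS_REGISTRY` and binding it to a SQLite engine, mirroring how the pipelines above construct views; the database path assumes the data setup described in `benchmarks/sql/README.md`:

```python
from sqlalchemy import create_engine

from bench.views import VIEWS_REGISTRY

engine = create_engine("sqlite:///data/superhero.db")
view_cls = VIEWS_REGISTRY["SuperheroView"]  # class registered under its __name__ above
view = view_cls(engine)                     # views in this suite are constructed from an engine
```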
File renamed without changes.