
Commit 2714e7c

feat(extra): prompt tuning (#79)
1 parent 11a7b21 commit 2714e7c

24 files changed (+407, -10 lines)

benchmarks/README.md

+2 -2

````diff
@@ -4,10 +4,10 @@ This folder contains scripts that produce reproducible timings and evaluation me
 
 ## Setup environment
 
-Before installing any package, make sure you have Python 3.8 or higher installed on your machine. From the root directory of the project, install the dependencies:
+Before installing any package, make sure you have Python 3.9 or higher installed on your machine. From the root directory of the project, install the dependencies:
 
 ```bash
-pip install -e '.[benchmarks]'
+pip install -e '.[dev]'
 ```
 
 ## Benchmark list
````

extra/README.md

+13
@@ -0,0 +1,13 @@

# Extra

This folder contains scripts for research experiments related to dbally. Links are provided where descriptions exist:

- [`Prompt tuning`](prompt_tuning/README.md)

## Setup environment

Before installing any package, make sure you have Python 3.9 or higher installed on your machine. From the root directory of the project, install the dependencies:

```bash
pip install -e '.[dev]'
```

extra/prompt_tuning/README.md

+42
@@ -0,0 +1,42 @@

# Prompt tuning

This folder contains scripts for prompt tuning and evaluation. Prompts (programs) used in dbally:

- `FILTERING_ASSESSOR` - assesses whether a question requires filtering.

All evaluations are run on a dev split of the [BIRD](https://bird-bench.github.io/) dataset. For now, one configuration is available to run the suite against the `superhero` database.

## Usage

Run evaluation of the filtering assessor baseline on the `superhero` database with `gpt-3.5-turbo`:

```bash
python evaluate.py program=filtering-assessor-baseline
```

Test multiple programs:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline,filtering-assessor-cot
```

Compare prompt performance across multiple LLMs:

```bash
python evaluate.py --multirun program=filtering-assessor-baseline llm=gpt-3.5-turbo,claude-3.5-sonnet
```

### Log to Neptune

Before running the evaluation with Neptune, configure the following environment variables:

```bash
export NEPTUNE_API_TOKEN="API_TOKEN"
export NEPTUNE_PROJECT="WORKSPACE_NAME/PROJECT_NAME"
```

Export evaluation results to Neptune:

```bash
python evaluate.py neptune=True
```
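
For orientation, here is a minimal sketch of how one of these prompt programs can be invoked directly through dspy, outside the Hydra-driven `evaluate.py` flow. It assumes an OpenAI API key is configured in the environment, and the question text is purely illustrative; the `PROGRAMS` registry, `FilteringAssessorBaseline`, and the OpenAI client usage mirror the code added in this commit.

```python
# Minimal sketch: run the filtering assessor on a single question.
# Assumes OPENAI_API_KEY is set; the question below is illustrative, not from BIRD.
import dspy

from tuning.programs import PROGRAMS

# Same provider/model resolution as evaluate.py, spelled out explicitly.
lm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=lm)

program = PROGRAMS["FilteringAssessorBaseline"]()
prediction = program(question="Which superheroes are taller than 200 cm?")
print(prediction.decision)  # True if the model judges that the question needs filtering
```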
+7

@@ -0,0 +1,7 @@

```yaml
defaults:
- data: superhero
- llm: gpt-3.5-turbo
- program: filtering-assessor-baseline
- _self_

neptune: False
```

@@ -0,0 +1,4 @@

```yaml
path: "micpst/bird-iql"
split: "dev"
db_ids: ["superhero"]
difficulties: ["simple", "moderate", "challenging"]
```

@@ -0,0 +1,2 @@

```yaml
model_name: claude-3-haiku-20240307
provider: Claude
```

@@ -0,0 +1,2 @@

```yaml
model_name: claude-3-opus-20240229
provider: Claude
```

@@ -0,0 +1,2 @@

```yaml
model_name: claude-3-5-sonnet-20240620
provider: Claude
```

@@ -0,0 +1,2 @@

```yaml
model_name: gpt-3.5-turbo
provider: OpenAI
```

@@ -0,0 +1,2 @@

```yaml
model_name: gpt-4-turbo
provider: OpenAI
```

@@ -0,0 +1,2 @@

```yaml
model_name: gpt-4o
provider: OpenAI
```

@@ -0,0 +1,2 @@

```yaml
type: FILTERING_ASSESSOR
name: FilteringAssessorBaseline
```

@@ -0,0 +1,2 @@

```yaml
type: FILTERING_ASSESSOR
name: FilteringAssessorCoT
```
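
Composed by Hydra, the defaults list above plus the selected group files resolve to roughly the following configuration for a default run. This is a sketch assembled from the files in this commit, not an output dump:

```yaml
# Approximate composed configuration for `python evaluate.py` with the defaults above.
data:
  path: "micpst/bird-iql"
  split: "dev"
  db_ids: ["superhero"]
  difficulties: ["simple", "moderate", "challenging"]
llm:
  model_name: gpt-3.5-turbo
  provider: OpenAI
program:
  type: FILTERING_ASSESSOR
  name: FilteringAssessorBaseline
neptune: False
```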

extra/prompt_tuning/evaluate.py

+101
@@ -0,0 +1,101 @@

```python
import asyncio
import logging
from enum import Enum
from pathlib import Path

import dspy
import hydra
import neptune
from dspy.evaluate import Evaluate
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig
from tuning.loaders import IQLGenerationDataLoader
from tuning.metrics import filtering_assess_acc
from tuning.programs import PROGRAMS
from tuning.utils import save, serialize_results

logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("anthropic").setLevel(logging.ERROR)
log = logging.getLogger(__name__)


class EvaluationType(Enum):
    """
    Enum representing the evaluation type.
    """

    FILTERING_ASSESSOR = "FILTERING_ASSESSOR"


EVALUATION_DATALOADERS = {
    EvaluationType.FILTERING_ASSESSOR.value: IQLGenerationDataLoader,
}

EVALUATION_METRICS = {
    EvaluationType.FILTERING_ASSESSOR.value: filtering_assess_acc,
}


async def evaluate(config: DictConfig) -> None:
    """
    Function running evaluation for all datasets and evaluation tasks defined in hydra config.

    Args:
        config: Hydra configuration.
    """
    log.info("Starting evaluation: %s", config.program.name)

    dataloader = EVALUATION_DATALOADERS[config.program.type](config)
    metric = EVALUATION_METRICS[config.program.type]
    program = PROGRAMS[config.program.name]()

    dataset = await dataloader.load()

    lm = dspy.__dict__[config.llm.provider](model=config.llm.model_name)
    dspy.settings.configure(lm=lm)

    evaluator = Evaluate(
        devset=dataset,
        metric=metric,
        num_threads=32,
        display_progress=True,
        return_outputs=True,
    )
    metric, results = evaluator(program)

    log.info("Evaluation finished. Saving results...")

    output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
    results_file = output_dir / "results.json"
    save(results_file, results=serialize_results(results))

    log.info("Evaluation results saved under directory: %s", output_dir)

    if config.neptune:
        run = neptune.init_run()
        run["sys/tags"].add(
            [
                config.program.type,
                config.program.name,
                *config.data.db_ids,
                *config.data.difficulties,
            ]
        )
        run["config"] = stringify_unsupported(config)
        run["evaluation/metrics/ACC"] = stringify_unsupported(metric)
        run["evaluation/results.json"].upload(results_file.as_posix())


@hydra.main(config_path="config", config_name="config", version_base="3.2")
def main(config: DictConfig) -> None:
    """
    Function running evaluation for all datasets and evaluation tasks defined in hydra config.

    Args:
        config: Hydra configuration.
    """
    asyncio.run(evaluate(config))


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
```
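
One detail worth calling out: `lm = dspy.__dict__[config.llm.provider](model=config.llm.model_name)` looks up the LM client class on the `dspy` module by the `provider` string from the llm config. The provider values in the configs are "OpenAI" and "Claude", so the lookup is equivalent to the explicit calls sketched below (model names taken from the config files; treat the explicit spelling as an illustration, not the repo's documented API).

```python
# Sketch: what dspy.__dict__[provider](model=...) resolves to for the two providers
# used in the llm configs above.
import dspy

lm_openai = dspy.OpenAI(model="gpt-3.5-turbo")                # provider: OpenAI
lm_claude = dspy.Claude(model="claude-3-5-sonnet-20240620")   # provider: Claude

# evaluate.py then makes the chosen client the default LM for all dspy programs:
dspy.settings.configure(lm=lm_openai)
```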

extra/prompt_tuning/tuning/__init__.py

Whitespace-only changes.

extra/prompt_tuning/tuning/loaders.py

+69
@@ -0,0 +1,69 @@

```python
from abc import ABC, abstractmethod
from typing import Dict, Iterable, List

import dspy.datasets
from dspy import Example


class DataLoader(ABC):
    """
    Data loader.
    """

    def __init__(self, config: Dict) -> None:
        self.config = config

    @abstractmethod
    async def load(self) -> Iterable:
        """
        Load the data.

        Returns:
            The loaded data.
        """


class HuggingFaceDataLoader(DataLoader):
    """
    Hugging Face data loader.
    """

    async def load(self) -> List[Example]:
        """
        Load the data from Hugging Face.

        Returns:
            The loaded data.
        """
        dataloader = dspy.datasets.DataLoader()
        dataset = dataloader.from_huggingface(
            dataset_name=self.config.data.path, split=self.config.data.split, input_keys=("question",)
        )
        return [
            data
            for data in dataset
            if data["question"]
            if (
                data["db_id"] in self.config.data.db_ids
                if self.config.data.db_ids
                else True and data["difficulty"] in self.config.data.difficulties
                if self.config.data.difficulties
                else True
            )
        ]


class IQLGenerationDataLoader(HuggingFaceDataLoader):
    """
    Data loader for IQL generation evaluation.
    """

    async def load(self) -> List[Example]:
        """
        Load the data from Hugging Face and filter out samples without views.

        Returns:
            The loaded data.
        """
        dataset = await super().load()
        return [data for data in dataset if data["view_name"]]
```
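
The nested conditional expression in `HuggingFaceDataLoader.load` is compact but easy to misread. Below is a hedged restatement of the row filter as a plain helper (a sketch, not the committed code). Note that, because conditional expressions associate to the right, the comprehension above applies the difficulty check only when `db_ids` is empty; the helper implements the presumably intended behaviour of applying both checks independently whenever they are configured.

```python
# Hedged sketch: explicit restatement of the row filter, assuming the intent is
# "keep rows that have a question, match db_ids when given, and match difficulties when given".
from typing import Mapping, Sequence


def keep_row(data: Mapping, db_ids: Sequence[str], difficulties: Sequence[str]) -> bool:
    """Return True if the row passes the configured filters."""
    if not data["question"]:
        return False
    if db_ids and data["db_id"] not in db_ids:
        return False
    if difficulties and data["difficulty"] not in difficulties:
        return False
    return True


# Usage inside load() would then read:
# return [data for data in dataset
#         if keep_row(data, self.config.data.db_ids, self.config.data.difficulties)]
```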
@@ -0,0 +1,3 @@

```python
from .iql import filtering_assess_acc

__all__ = ["filtering_assess_acc"]
```
+19
@@ -0,0 +1,19 @@

```python
from typing import Dict

from dspy import Prediction


def filtering_assess_acc(gold: Dict, pred: Prediction) -> bool:
    """
    IQL decision metric.

    Args:
        gold: The ground truth data point.
        pred: The prediction.

    Returns:
        The decision metric.
    """
    return ((gold["iql_filters"] is None and not gold["iql_filters_unsupported"]) and not pred.decision) or (
        (gold["iql_filters"] is not None or gold["iql_filters_unsupported"]) and pred.decision
    )
```
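
In other words, the metric treats an example as positive when the gold IQL has any filters (supported or unsupported) and checks that the program's boolean decision agrees. A small sketch with hand-made gold/prediction pairs (illustrative values, not rows from the dataset; the filter expression is hypothetical):

```python
# Illustrative check of filtering_assess_acc on hand-made examples.
from dspy import Prediction

from tuning.metrics import filtering_assess_acc

# Gold needs no filters, program also says False -> metric is True.
gold_no_filters = {"iql_filters": None, "iql_filters_unsupported": False}
print(filtering_assess_acc(gold_no_filters, Prediction(decision=False)))   # True

# Gold has a filter expression, program says True -> metric is True.
gold_with_filters = {"iql_filters": "filter_by_height(200)", "iql_filters_unsupported": False}
print(filtering_assess_acc(gold_with_filters, Prediction(decision=True)))  # True

# Disagreement -> metric is False.
print(filtering_assess_acc(gold_with_filters, Prediction(decision=False)))  # False
```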
@@ -0,0 +1,8 @@

```python
from .iql import FilteringAssessorBaseline, FilteringAssessorCoT

PROGRAMS = {
    FilteringAssessorBaseline.__name__: FilteringAssessorBaseline,
    FilteringAssessorCoT.__name__: FilteringAssessorCoT,
}

__all__ = ["PROGRAMS", "FilteringAssessorBaseline", "FilteringAssessorCoT"]
```
+49
@@ -0,0 +1,49 @@

```python
from dspy import ChainOfThought, Module, Predict, Prediction

from ..signatures.iql import CheckQuestionFiltering


class FilteringAssessorBaseline(Module):
    """
    Program that assesses whether a question requires filtering.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = Predict(CheckQuestionFiltering)

    def forward(self, question: str) -> Prediction:
        """
        Assess whether a question requires filtering.

        Args:
            question: The question to assess.

        Returns:
            The prediction.
        """
        decision = self.decide(question=question).decision
        return Prediction(decision=decision.lower() == "true")


class FilteringAssessorCoT(Module):
    """
    Program that assesses whether a question requires filtering.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decide = ChainOfThought(CheckQuestionFiltering)

    def forward(self, question: str) -> Prediction:
        """
        Assess whether a question requires filtering.

        Args:
            question: The question to assess.

        Returns:
            The prediction.
        """
        decision = self.decide(question=question).decision
        return Prediction(decision=decision.lower() == "true")
```
@@ -0,0 +1,3 @@

```python
from .iql import CheckQuestionFiltering

__all__ = ["CheckQuestionFiltering"]
```
@@ -0,0 +1,20 @@

```python
from dspy import InputField, OutputField, Signature


class CheckQuestionFiltering(Signature):
    """
    Given a question, determine whether the answer requires initial data filtering in order to compute it.
    Initial data filtering is a process in which the result set is reduced to only include the rows that
    meet certain criteria specified in the question.
    """

    question = InputField(
        prefix="Question: ",
    )
    decision = OutputField(
        prefix="Decision: ",
        desc=(
            "indicates whether the answer to the question requires initial data filtering. "
            "(Respond with True or False)"
        ),
    )
```
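
The signature's docstring becomes the task instructions and the field prefixes shape the completion format. A quick way to see the rendered prompt is to run the signature through `dspy.Predict` and inspect the LM history, as sketched below (assumes an OpenAI key is configured; the question is illustrative and the exact prompt boilerplate depends on the dspy version):

```python
# Sketch: inspect the prompt dspy builds from CheckQuestionFiltering.
import dspy

from tuning.signatures.iql import CheckQuestionFiltering

lm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=lm)

predict = dspy.Predict(CheckQuestionFiltering)
print(predict(question="How many superheroes are there?").decision)  # expected "True" or "False"
lm.inspect_history(n=1)  # prints the instructions plus the "Question: ..." / "Decision: ..." fields
```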
