
Commit 6a05d29

Fix bugs in retriever sdg notebook (#522) (#560)
* Signed-off by [email protected]
* Signed-off by [email protected]
* fixed qa bug 5008113, Signed-off by [email protected]
* bug fixes for generator, Signed-off by [email protected]
* fixed precommit, Signed-off by [email protected]
* fixed filters, Signed-off by [email protected]
* fixed all issues, Signed-off by [email protected]
* fixed bug with document id, Signed-off by [email protected]
* check if filtering pipeline is present, Signed-off by [email protected]
* fixed notebook, Signed-off by [email protected]
* added functionality to filter pre-generated datasets, Signed-off by [email protected]
* separated generation & filtering pipelines, Signed-off by [email protected]
* fixed pre-commit, Signed-off by [email protected]
* minor changes, Signed-off by [email protected]
* fixed Ryan Wolf's comments, Signed-off by [email protected]
* fixed minor bugs in configs, Signed-off by [email protected]
* removed commented code in main.py, Signed-off by [email protected]
* added CLI flags for generation & filtering, removed code duplication, Signed-off by [email protected]
* minor fix to quickstart notebook, Signed-off by [email protected]
* removed filter.py & generate.py, Signed-off by [email protected]

---------

Signed-off-by: viraman <[email protected]>
Signed-off-by: Vinay Raman <[email protected]>
Co-authored-by: vinay-raman <[email protected]>
1 parent 96a49b7 commit 6a05d29

File tree

9 files changed: +332 −180 lines changed


nemo_curator/filters/synthetic.py

+32 −19

@@ -24,6 +24,15 @@
 
 from nemo_curator.filters.doc_filter import DocumentFilter
 from nemo_curator.utils.decorators import batched
+from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker
+
+
+def create_client(base_url, api_key):
+    openai_client = OpenAI(
+        base_url=base_url,
+        api_key=api_key,
+    )
+    return openai_client
 
 
 # ----------------------------------------------------------------------------80
@@ -52,16 +61,21 @@ def __init__(
         self.percentile = percentile
         if truncate:
             self.truncate = truncate
-        try:
-            self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
-        except Exception as e:
-            print(f"Error accessing NIM model: {e}")
         self.batch_size = batch_size
         self.text_fields = text_fields
 
     @batched
     def score_document(self, df: pd.DataFrame):
 
+        try:
+            self.client = load_object_on_worker(
+                attr="openai_client_easiness",
+                load_object_function=create_client,
+                load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key},
+            )
+        except NoWorkerError:
+            return pd.Series(np.ones(len(df)), dtype=float)
+
         document_score = self._calc_similarity_nim(
             df[self.text_fields[0]].to_list(), df[self.text_fields[1]].to_list()
         )
@@ -90,7 +104,7 @@ def _get_nim_embedding(self, text, input_type):
             print(f"Error: {e}")
             response = None
 
-        if response:
+        if response and not isinstance(response, str):
             if isinstance(text, list):
                 embeddings = [r.embedding for r in response.data]
             elif isinstance(text, str):
@@ -116,9 +130,6 @@ def _calc_similarity_nim(self, context, question):
 
         return sim
 
-    def __dask_tokenize__(self):
-        return normalize_token(EasinessFilter)
-
 
 # ----------------------------------------------------------------------------80
 # ----------------------- Answerability Filter ---------------------------------
@@ -149,19 +160,24 @@ def __init__(
         self.system_prompt = answerability_system_prompt
         self.user_prompt_template = answerability_user_prompt_template
         self.num_criteria = num_criteria
-
-        try:
-            self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
-        except Exception as e:
-            print(f"Error accessing NIM model: {e}")
-
         self.text_fields = text_fields
 
     @batched
     def score_document(self, df: pd.DataFrame):
-        return df.apply(
+
+        try:
+            self.client = load_object_on_worker(
+                attr="openai_client_answerability",
+                load_object_function=create_client,
+                load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key},
+            )
+        except NoWorkerError:
+            return pd.Series(["string"] * len(df))
+
+        return df.progress_apply(
             lambda row: self._llm_as_judge(
-                row[self.text_fields[0]], row[self.text_fields[1]]
+                row[self.text_fields[0]],
+                row[self.text_fields[1]],
             ),
             axis=1,
         )
@@ -212,8 +228,5 @@ def _llm_as_judge(self, context: str, question: str):
 
         return generation
 
-    def __dask_tokenize__(self):
-        return normalize_token(AnswerabilityFilter)
-
 
 # ----------------------------------------------------------------------------80
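For context on the change above: both filters now create the OpenAI-compatible client lazily, at scoring time, via `load_object_on_worker`, caching it per Dask worker under a named attribute, and they return a placeholder Series when no worker is available (`NoWorkerError`). A minimal standalone sketch of that pattern follows; the function name `score_batch`, the attribute `"openai_client_example"`, and the constant scores are illustrative, not part of the commit.

```python
# Minimal sketch of the per-worker client pattern used in the diff above.
import pandas as pd
from openai import OpenAI

from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker


def create_client(base_url, api_key):
    # Build the OpenAI-compatible client on the Dask worker that will use it.
    return OpenAI(base_url=base_url, api_key=api_key)


def score_batch(df: pd.DataFrame, base_url: str, api_key: str) -> pd.Series:
    try:
        # Create the client once per worker and cache it under a named attribute,
        # instead of constructing it in __init__ on the client process.
        client = load_object_on_worker(
            attr="openai_client_example",  # illustrative attribute name
            load_object_function=create_client,
            load_object_kwargs={"base_url": base_url, "api_key": api_key},
        )
    except NoWorkerError:
        # No worker available: return a placeholder Series of the expected dtype,
        # mirroring what the filters above do.
        return pd.Series([1.0] * len(df), dtype=float)

    # Real scoring would call the NIM endpoint through `client` here;
    # constant scores stand in for that step in this sketch.
    return pd.Series([1.0] * len(df), dtype=float)
```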

tutorials/nemo-retriever-synthetic-data-generation/README.md

+21 −8

@@ -45,22 +45,35 @@ Navigate to the [quick start notebook](notebooks/quickstart.ipynb) and follow th
 
 ### Run Pipeline (CLI)
 
-The pipeline can be run with datasets in rawdoc (only text, title and ids if any) format. To test the pipeline, you can use the provided example data at ```sample_data_rawdoc.jsonl```
+The pipeline can be run with datasets in ```jsonl``` (only text, title and ids if any) format. To test the pipeline, you can use the provided example data at ```sample_data/sample_data_rawdoc.jsonl```
 
-Navigate to the top level of this project directory and run the following command in your command line. It will take roughly 5-10 minutes.
+To use jsonl format, provide your data in a single or multiple `.jsonl` files. The structure of the data should follow this format: `{"text": <document>, "title": <title>}`. Additionally, if the documents already have a document id, the input file can also contain document ids. The same ids will be persisted in the generated data as well. Another accepted format is `{"_id": <document_id>, "text": <document>, "title": <title>}`.
 
-- `Rawdoc format`
-
-To use rawdoc format, provide your data in a `.jsonl` file. The structure of the data should follow this format: `{"text": <document>, "title": <title>}`. Additionally, if the documents already have a document id, the input file can also contain document ids. The same ids will be persisted in the generated data as well. Another accepted format is `{"_id": <document_id>, "text": <document>, "title": <title>}`.
+The pipeline can be run in two modes (1. Generation and 2. Filtering). In order to run the full pipeline in generation mode, use the script ```main.py``` with the flag ```--pipeline-type=generate```
+```
+python tutorials/nemo-retriever-synthetic-data-generation/main.py \
+  --api-key=<API Key> \
+  --input-dir=tutorials/nemo-retriever-synthetic-data-generation/sample_data \
+  --pipeline-config=tutorials/nemo-retriever-synthetic-data-generation/config/config.yaml\
+  --input-format=jsonl \
+  --pipeline-type=generate \
+  --output-dir=tutorials/nemo-retriever-synthetic-data-generation/outputs/sample_data_rawdoc
+  --save-format=jsonl
+  --n-partitions=5
+```
+The data can be saved in two formats (1. jsonl, 2. beir). Additionally, the user can pass ```--n-partitions``` flag to speed-up generation for large datasets.
 
-In order to run the pipeline, use the script ```main.py```
+To filter pre-generated data, run ```main.py``` with ```--pipeline-type=filter```
+Note the change in the ```input-dir```, we need to use the path to the generated data in jsonl format.
 ```
 python tutorials/nemo-retriever-synthetic-data-generation/main.py \
   --api-key=<API Key> \
-  --input-file=tutorials/nemo-retriever-synthetic-data-generation/data/sample_data_rawdoc.jsonl \
+  --input-dir= tutorials/nemo-retriever-synthetic-data-generation/outputs/sample_data_rawdoc/jsonl \
   --pipeline-config=tutorials/nemo-retriever-synthetic-data-generation/config/config.yaml\
-  --input-format=rawdoc \
+  --input-format=jsonl \
+  --pipeline-type=filter \
   --output-dir=tutorials/nemo-retriever-synthetic-data-generation/outputs/sample_data_rawdoc
+  --save-format=jsonl
 ```
 
 For more information about the expected structure of the data, see the [quick start notebook](notebooks/quickstart.ipynb).
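As a quick illustration of the jsonl input layout the updated README describes, the sketch below writes a tiny input file in both accepted record shapes; the file path and record contents are made up, not taken from the commit.

```python
# Illustrative only: build a small input file in the jsonl layout described above.
import json

records = [
    # Minimal form: text + title (document ids optional).
    {"text": "An EIN is a federal tax identifier for a business.", "title": "EIN basics"},
    # Alternative form with a pre-existing id, which is persisted in the generated data.
    {"_id": "doc-0001", "text": "An LLC may need a new EIN once it hires employees.", "title": "LLC hiring"},
]

with open("my_sample_data.jsonl", "w") as f:  # hypothetical path
    for record in records:
        f.write(json.dumps(record) + "\n")
```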

tutorials/nemo-retriever-synthetic-data-generation/config/config-fiqa.yaml

+2 −2

@@ -55,7 +55,7 @@ generator_system_prompt: |
   Do I need a new EIN since I am hiring employees for my LLC?
 
 user_prompt_template: |
-  Generate {num_questions} questions and corresponding answers based on Input Document.
+  Generate {n_openlines} questions and corresponding answers based on Input Document.
 
   Input Document:
   {document}
@@ -72,7 +72,7 @@ percentile: 70 # Percentile for threshold calculation (float) [0, 100]
 batch_size: 1
 
 #Answerability filter (LLM-as-judge)
-answerability_filter: "meta/llama3-70b-instruct"
+answerability_filter: "meta/llama-3.1-70b-instruct"
 num_criteria: 4 # Number of criteria to parse from the response. It must be alined with the prompt template
 answerability_system_prompt: |
   You are an evaluator who is rating questions to given context passages based on the given criteria. Assess the given question for clarity and answerability given enough domain knowledge, consider the following evaluation criterion:
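The placeholder rename above (`{num_questions}` to `{n_openlines}`) only takes effect if the template is rendered with a matching keyword. A minimal sketch of how such a template could be filled, assuming ordinary Python `str.format` substitution; the sample values are made up and the rendering mechanism is an assumption, not confirmed by the diff.

```python
# Illustrative only: rendering the renamed prompt placeholder with str.format.
user_prompt_template = (
    "Generate {n_openlines} questions and corresponding answers based on Input Document.\n"
    "\n"
    "Input Document:\n"
    "{document}\n"
)

prompt = user_prompt_template.format(
    n_openlines=3,  # number of question/answer pairs to request (sample value)
    document="Do I need a new EIN since I am hiring employees for my LLC?",
)
print(prompt)
```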

tutorials/nemo-retriever-synthetic-data-generation/config/config-nq.yaml

+1 −1

@@ -72,7 +72,7 @@ percentile: 70 # Percentile for threshold calculation (float) [0, 100]
 batch_size: 1
 
 #Answerability filter (LLM-as-judge)
-answerability_filter: "meta/llama3-70b-instruct"
+answerability_filter: "meta/llama-3.1-70b-instruct"
 num_criteria: 4 # Number of criteria to parse from the response. It must be alined with the prompt template
 answerability_system_prompt: |
   You are an evaluator who is rating questions to given context passages based on the given criteria. Assess the given question for clarity and answerability given enough domain knowledge, consider the following evaluation criterion:

tutorials/nemo-retriever-synthetic-data-generation/config/config.yaml

+1 −1

@@ -63,7 +63,7 @@ percentile: 70 # Percentile for threshold calculation (float) [0, 100]
 batch_size: 1
 
 #Answerability filter (LLM-as-judge)
-answerability_filter: "meta/llama3-70b-instruct"
+answerability_filter: "meta/llama-3.1-70b-instruct"
 num_criteria: 4 # Number of criteria to parse from the response. It must be alined with the prompt template
 answerability_system_prompt: |
   You are an evaluator who is rating questions to given context passages based on the given criteria. Assess the given question for clarity and answerability given enough domain knowledge, consider the following evaluation criterion:
