add input_output_map and multi-vote #199

Merged 366 commits on Aug 11, 2023

Changes from all commits
4c907bc
modify prompt templete
zhaochen20 Jun 28, 2023
91643e5
modify prompt generation
zhaochen20 Jun 28, 2023
4c00db8
add new prompt for generator
zhaochen20 Jun 29, 2023
03c5d5c
add new prompt for generator
zhaochen20 Jun 29, 2023
bd16aa8
add new prompt for generator to solve unexpected Json of Json
zhaochen20 Jun 29, 2023
3bb00c9
delete duplication
zhaochen20 Jun 29, 2023
58943ef
new mock function
zhaochen20 Jun 30, 2023
39d2e1c
Merge branch 'dataset_test_modification' into self-instruct
zhaochen20 Jun 30, 2023
f3d0dac
new prompt template
zhaochen20 Jul 2, 2023
21d83ac
new prompt template for dataset generator
zhaochen20 Jul 2, 2023
92a23aa
add random generted few-shot examples
zhaochen20 Jul 2, 2023
eb608a7
split into two prompt template
zhaochen20 Jul 2, 2023
b49bf3b
Merge branch 'self-instruct' into dataset_test_modification
zhaochen20 Jul 2, 2023
51254ab
new middle template
zhaochen20 Jul 3, 2023
9eb9f10
add left padding
zhaochen20 Jul 4, 2023
7c97880
add max_new_tokens
zhaochen20 Jul 4, 2023
f5a6245
add max_new_tokens
zhaochen20 Jul 4, 2023
cb923b9
add max_new_tokens and control for GPT2
zhaochen20 Jul 5, 2023
96491e8
debug single generation for executor
zhaochen20 Jul 5, 2023
cf12a3c
add new unit test for model executor
zhaochen20 Jul 5, 2023
e5ba60d
add <eos> for processor
zhaochen20 Jul 6, 2023
6d507be
add new test for trainer
zhaochen20 Jul 6, 2023
87fad8c
Refactor model executor
zhaochen20 Jul 6, 2023
fe0cbc8
[use max length]
zhaochen20 Jul 6, 2023
d129d50
[use max length] strict beam search
zhaochen20 Jul 6, 2023
a9946a5
add more constraint on sequence length
zhaochen20 Jul 7, 2023
944fb8c
release an if control in trainer
zhaochen20 Jul 7, 2023
f85c790
Merge branch 'main' of github.com:viswavi/prompt2model
zhaochen20 Jul 7, 2023
d5e038a
Merge branch 'main' of github.com:viswavi/prompt2model
zhaochen20 Jul 7, 2023
cd9806e
fix conflict in gtp2
zhaochen20 Jul 7, 2023
eea35f8
fix conflict in batch enocde plus
zhaochen20 Jul 7, 2023
dd0de57
Merge branch 'use_batch_encode_plus' into rename_evaluator
zhaochen20 Jul 7, 2023
76a9a21
Merge branch 'rename_evaluator' into debug_label
zhaochen20 Jul 7, 2023
07a3cf8
Merge branch 'debug_label' into add_metric_for_autoregressive
zhaochen20 Jul 7, 2023
b81fb90
fix conflict in debug label
zhaochen20 Jul 7, 2023
9eb8b55
Merge branch 'main' into add_metric_for_autoregressive
zhaochen20 Jul 7, 2023
6a7813b
Add more device for model executor (#146)
zhaochenyang20 Jul 7, 2023
1bd29ba
Merge branch 'add_metric_for_autoregressive' of github.com:viswavi/pr…
zhaochen20 Jul 7, 2023
014793a
use small models at all
zhaochen20 Jul 7, 2023
a1cc790
solve conflict from add_metric_for_autoregressive
zhaochen20 Jul 7, 2023
ed14e85
Merge branch 'add_label_in_processor' into openai_key-can-not-found
zhaochen20 Jul 7, 2023
fe00acf
Merge branch 'openai_key-can-not-found' into new_meta_prompt
zhaochen20 Jul 7, 2023
aa0a1fb
Merge branch 'new_meta_prompt' into dataset_test_modification
zhaochen20 Jul 7, 2023
d2a725a
use small model for testing
zhaochen20 Jul 7, 2023
2a80ccf
use small model for testing, fix format
zhaochen20 Jul 7, 2023
55deb9c
inference on cuda
zhaochen20 Jul 7, 2023
cbb1336
split test helper of model creator
zhaochen20 Jul 8, 2023
5d75cf1
revert generator
zhaochen20 Jul 8, 2023
57b5b39
revert generator
zhaochen20 Jul 8, 2023
f65eb83
add warning for none constraints of sequence_max_length
zhaochen20 Jul 8, 2023
651ef89
fix lint
zhaochen20 Jul 8, 2023
4cbc236
fix lint
zhaochen20 Jul 8, 2023
09dc70a
add test for trunction in trainer
zhaochen20 Jul 8, 2023
413aa6a
fix review
zhaochen20 Jul 9, 2023
ecd5d2b
add label to processor (#149)
zhaochenyang20 Jul 9, 2023
d4459f6
Merge branch 'add_metric_for_autoregressive' of github.com:viswavi/pr…
zhaochen20 Jul 9, 2023
41d3485
Merge branch 'add_metric_for_autoregressive' into openai_key-can-not-…
zhaochen20 Jul 9, 2023
3e06625
fix review
zhaochen20 Jul 9, 2023
dd73685
Merge branch 'openai_key-can-not-found' into new_meta_prompt
zhaochen20 Jul 9, 2023
5e69b42
fix review. change name
zhaochen20 Jul 9, 2023
027ebf1
fix review. change name
zhaochen20 Jul 9, 2023
d35c931
fix review. change name
zhaochen20 Jul 9, 2023
ee7bfca
fix review. change comment
zhaochen20 Jul 9, 2023
6a81a5f
Merge branch 'debug_autoregressive' into split_helper_of_model_create
zhaochen20 Jul 9, 2023
7778e6a
Merge branch 'split_helper_of_model_create' into one_executor_fits_all
zhaochen20 Jul 9, 2023
341030d
fix lint
zhaochen20 Jul 10, 2023
114d168
add label unit test
zhaochen20 Jul 10, 2023
bc1d69c
add comments for unit tests
zhaochen20 Jul 10, 2023
4719877
add comments for the magic number -100
zhaochen20 Jul 10, 2023
5e4341b
add delimiter for processor, add new test for tokenization of trainer
zhaochen20 Jul 10, 2023
ad35626
fix typecheck for 3.9 and 3.10
zhaochen20 Jul 10, 2023
bbbaf26
add , add new unit test for tokenize dataset and add comments.
zhaochen20 Jul 10, 2023
0122580
greedy decode
zhaochen20 Jul 11, 2023
34a1cc1
greedy decode
zhaochen20 Jul 11, 2023
061a7e0
fix bug for eos_token_id
zhaochen20 Jul 11, 2023
d3a7d56
add repetition_penalty
zhaochen20 Jul 11, 2023
144e2c6
add print to prediction
zhaochen20 Jul 11, 2023
d310119
new tests for autoregressive model
zhaochen20 Jul 11, 2023
badeaef
new tests for autoregressive model
zhaochen20 Jul 11, 2023
10936bf
fix comment
zhaochen20 Jul 12, 2023
f1a0b91
fix typo
zhaochen20 Jul 12, 2023
45dfce6
Merge branch 'add_metric_for_autoregressive' into openai_key-can-not-…
zhaochen20 Jul 12, 2023
97945eb
fix assert, add new unit test
zhaochen20 Jul 12, 2023
bfbac91
Merge branch 'main' of github.com:viswavi/prompt2model
zhaochen20 Jul 12, 2023
1e62722
fix conflict
zhaochen20 Jul 12, 2023
47bcd84
Merge branch 'openai_key-can-not-found' into new_meta_prompt
zhaochen20 Jul 12, 2023
8c55a31
Merge branch 'new_meta_prompt' into dataset_test_modification
zhaochen20 Jul 12, 2023
84f5cfd
fix conflict
zhaochen20 Jul 12, 2023
3f428a3
Merge branch 'debug_autoregressive' into split_helper_of_model_create
zhaochen20 Jul 12, 2023
cb0ad1f
Merge branch 'split_helper_of_model_create' into one_executor_fits_all
zhaochen20 Jul 12, 2023
951d854
Merge branch 'one_executor_fits_all' into fix_label_for_decoder_only
zhaochen20 Jul 12, 2023
42db8ac
add new test cases and real test for trainer
zhaochen20 Jul 12, 2023
5ca3dfe
delete confidence
zhaochen20 Jul 13, 2023
a010399
refactor unit tests
zhaochen20 Jul 13, 2023
8351ed3
add 5 generation strategy
zhaochen20 Jul 13, 2023
6e0411f
add 5 generation strategy
zhaochen20 Jul 13, 2023
20928b0
fix OOM problem of trainer test
zhaochen20 Jul 13, 2023
4f662f2
fix OOM problem of trainer test
zhaochen20 Jul 13, 2023
946fe87
fix OOM problem of trainer test
zhaochen20 Jul 13, 2023
1ea151e
delete unecessary eos
zhaochen20 Jul 13, 2023
0b21140
delete unecessary eos
zhaochen20 Jul 13, 2023
8f89d0c
Merge branch 'main' of github.com:viswavi/prompt2model
zhaochen20 Jul 13, 2023
4876937
fix conflict
zhaochen20 Jul 13, 2023
27a3aaa
Merge branch 'dataset_test_modification' into debug_autoregressive
zhaochen20 Jul 13, 2023
394c105
Merge branch 'debug_autoregressive' into split_helper_of_model_create
zhaochen20 Jul 13, 2023
d0b8c92
Merge branch 'split_helper_of_model_create' into one_executor_fits_all
zhaochen20 Jul 13, 2023
29dd565
Merge branch 'one_executor_fits_all' into fix_label_for_decoder_only
zhaochen20 Jul 13, 2023
cde3df8
MMerge branch 'fix_label_for_decoder_only' into split_model_trainer_test
zhaochen20 Jul 13, 2023
7d735c3
fix conflict
zhaochen20 Jul 13, 2023
3734cbb
Merge branch 'calculate_confidence' into delete_unecessary_eos
zhaochen20 Jul 13, 2023
3466554
refactor evaluator
zhaochen20 Jul 14, 2023
cf15354
use sub set
zhaochen20 Jul 14, 2023
0b28f23
fix oom
zhaochen20 Jul 14, 2023
50a3b11
fix oom
zhaochen20 Jul 14, 2023
33513f8
add comments for evaluator
zhaochen20 Jul 14, 2023
1056c51
add new comments for the processor
zhaochen20 Jul 15, 2023
3a42674
Merge branch 'delete_unecessary_eos' into refactor_evaluator
zhaochen20 Jul 15, 2023
33173f2
refactor model executor
zhaochen20 Jul 15, 2023
e4f59a2
refactor training callback
zhaochen20 Jul 16, 2023
58cf094
gc.collect()
zhaochen20 Jul 16, 2023
d0f7a09
use gc.collect() to fix oom
zhaochen20 Jul 16, 2023
a0a8e9f
concatenate tests for trainer
zhaochen20 Jul 16, 2023
edb18fa
fix multi in unit tests
zhaochen20 Jul 17, 2023
c98b436
Component document (#132)
zhaochenyang20 Jul 18, 2023
028d77c
Merge branch 'refactor_callback' of github.com:viswavi/prompt2model i…
zhaochen20 Jul 18, 2023
c4c224d
Merge branch 'dataset_test_modification' of github.com:viswavi/prompt…
zhaochen20 Jul 18, 2023
4a5dac9
refactor generator
zhaochen20 Jul 18, 2023
fc268d4
Merge branch 'dataset_test_modification' into refactor_callback
zhaochen20 Jul 18, 2023
fd45d4f
refactor generator and add comments
zhaochen20 Jul 18, 2023
a1d262d
Merge branch 'dataset_test_modification' into refactor_callback
zhaochen20 Jul 18, 2023
fec7199
fix lint
zhaochen20 Jul 18, 2023
ff91e93
Merge branch 'dataset_test_modification' into refactor_callback
zhaochen20 Jul 18, 2023
6d4f1e8
fix lint
zhaochen20 Jul 18, 2023
ba64551
add new meta_prompt
zhaochen20 Jul 18, 2023
7f3d89f
Merge branch 'dataset_test_modification' into refactor_callback
zhaochen20 Jul 18, 2023
d71addc
add new meta examples
zhaochen20 Jul 19, 2023
2da9157
Merge branch 'dataset_test_modification' into refactor_callback
zhaochen20 Jul 19, 2023
e22b46e
change temperature
zhaochen20 Jul 19, 2023
faafe34
add new paramerters for generator
zhaochen20 Jul 19, 2023
cc70ed9
fix lint
zhaochen20 Jul 19, 2023
57f3d76
fix lint
zhaochen20 Jul 19, 2023
ed05b8e
add constraints to prompt too long
zhaochen20 Jul 19, 2023
28d1d11
Merge branch 'main' of github.com:viswavi/prompt2model
zhaochen20 Jul 20, 2023
4291046
fix confilcts
zhaochen20 Jul 20, 2023
4f442c5
use zeno in a batch
zhaochen20 Jul 20, 2023
78afc34
Merge branch 'use_zeno' into debug_autoregressive
zhaochen20 Jul 20, 2023
43c6079
fix confilits
zhaochen20 Jul 20, 2023
094dc70
Merge branch 'split_helper_of_model_create' into one_executor_fits_all
zhaochen20 Jul 20, 2023
3d4d450
Merge branch 'one_executor_fits_all' into fix_label_for_decoder_only
zhaochen20 Jul 20, 2023
44cc6f1
Merge branch 'fix_label_for_decoder_only' into split_model_trainer_test
zhaochen20 Jul 20, 2023
e6a12d9
fix lint
zhaochen20 Jul 20, 2023
d4d68cd
add gc.collect()
zhaochen20 Jul 20, 2023
fe6f20f
Merge branch 'split_model_trainer_test' into calculate_confidence
zhaochen20 Jul 20, 2023
a08b38f
add assertion for examples
zhaochen20 Jul 20, 2023
234ce93
fix conflicts
zhaochen20 Jul 20, 2023
f62a609
Merge branch 'refactor_evaluator' into refactor_executor
zhaochen20 Jul 20, 2023
c6a1741
fix conflicts
zhaochen20 Jul 20, 2023
72d712f
add error control
zhaochen20 Jul 21, 2023
f5c9ba7
Merge branch 'use_zeno' into debug_autoregressive
zhaochen20 Jul 21, 2023
3a78d11
Merge branch 'debug_autoregressive' into split_helper_of_model_create
zhaochen20 Jul 21, 2023
230de92
Merge branch 'split_helper_of_model_create' into split_model_trainer_…
zhaochen20 Jul 21, 2023
36b365f
Merge branch 'split_model_trainer_test' into calculate_confidence
zhaochen20 Jul 21, 2023
ed9adf1
Merge branch 'calculate_confidence' into refactor_evaluator
zhaochen20 Jul 21, 2023
096b053
Merge branch 'refactor_evaluator' into refactor_executor
zhaochen20 Jul 21, 2023
55420f2
Merge branch 'refactor_executor' into refactor_callback
zhaochen20 Jul 21, 2023
6c5a883
Merge branch 'refactor_callback' of github.com:neulab/prompt2model in…
zhaochen20 Jul 21, 2023
12b2335
Merge branch 'refactor_executor' of github.com:neulab/prompt2model in…
zhaochen20 Jul 21, 2023
b4dfce5
Merge branch 'refactor_executor' into refactor_callback
zhaochen20 Jul 21, 2023
5906386
fetech remote branch of `split_model_trainer_test`
zhaochen20 Jul 21, 2023
61a97a4
Merge to reduce abnormal diff
zhaochen20 Jul 21, 2023
af1a37f
Fetech error_handler
zhaochen20 Jul 21, 2023
53d9368
Fetch remote split_model_trainer_test
zhaochen20 Jul 21, 2023
981ceaa
Merge to fix abnormal diff
zhaochen20 Jul 21, 2023
eb387f2
Merge branch 'debug_autoregressive' of github.com:neulab/prompt2model…
zhaochen20 Jul 21, 2023
6ecc23a
Merge branch 'split_helper_of_model_create' of github.com:neulab/prom…
zhaochen20 Jul 21, 2023
9067cd0
Merge branch 'fix_label' of github.com:neulab/prompt2model into fix_l…
zhaochen20 Jul 21, 2023
be68175
Merge branch 'calculate_confidence' of github.com:neulab/prompt2model…
zhaochen20 Jul 21, 2023
35cabfd
fix error handle changes
zhaochen20 Jul 21, 2023
33f44c5
Merge branch 'refactor_evaluator' of github.com:neulab/prompt2model i…
zhaochen20 Jul 21, 2023
54d5efe
merge test files for trainer
zhaochen20 Jul 21, 2023
21fcbae
Merge branch 'refactor_executor' of github.com:neulab/prompt2model in…
zhaochen20 Jul 21, 2023
02b708e
Merge branch 'refactor_executor' of github.com:neulab/prompt2model in…
zhaochen20 Jul 21, 2023
f20f01e
fix conflicts
zhaochen20 Jul 21, 2023
397e75c
add tests for t5
zhaochen20 Jul 21, 2023
fa9b86d
add zeno's new release
zhaochen20 Jul 22, 2023
6ecb857
add multi-generation and debug
zhaochen20 Jul 22, 2023
26e1fc8
[set base parameter]
zhaochen20 Jul 23, 2023
87e9a15
add requests_per_minute and responses_per_requests
zhaochen20 Jul 23, 2023
e52c68e
set max_length to 500
zhaochen20 Jul 23, 2023
633e414
delecte wrong max_new_tokens in executor
zhaochen20 Jul 23, 2023
fd26123
distinguish high-equality and low-equality examples for generator
zhaochen20 Jul 23, 2023
63bd1fc
merge main. Fix readme
zhaochen20 Jul 26, 2023
533cf21
merge main. Fix readme
zhaochen20 Jul 26, 2023
f56e581
fetch remote branch of stress examples
zhaochen20 Jul 26, 2023
0d62481
add new member variables
zhaochen20 Jul 26, 2023
682045b
add input_output_map and multi-vote
zhaochen20 Jul 26, 2023
9b01717
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 5, 2023
ab195a0
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 8, 2023
aa66358
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 8, 2023
b8c9cf5
fix conflicts
zhaochen20 Aug 8, 2023
4c77c71
change the labeling logics
zhaochen20 Aug 8, 2023
eec2ad2
Merge branch 'fix_eos_label_adding' into fix_label
zhaochen20 Aug 8, 2023
b9cb3d5
add truncation warning for executor
zhaochen20 Aug 8, 2023
ed0362d
erge branch 'warning_executor' into fix_label
zhaochen20 Aug 8, 2023
57488a0
use execption in dataset generator
zhaochen20 Aug 8, 2023
3d6f0c6
use execption in dataset generator
zhaochen20 Aug 8, 2023
e43a794
use execption in dataset generator
zhaochen20 Aug 8, 2023
8cc6e60
merge main
zhaochen20 Aug 8, 2023
cbc48de
Merge branch 'calculate_confidence' into delete_unecessary_eos
zhaochen20 Aug 8, 2023
7ab073d
Merge branch 'delete_unecessary_eos' into refactor_evaluator
zhaochen20 Aug 8, 2023
ff8e73f
Merge branch 'refactor_evaluator' into refactor_executor
zhaochen20 Aug 8, 2023
a889528
merge main
zhaochen20 Aug 8, 2023
5859b73
merge main
zhaochen20 Aug 8, 2023
ee14c60
Merge branch 'refactor_callback' into add_multi_generation
zhaochen20 Aug 8, 2023
8893c5c
merge main
zhaochen20 Aug 8, 2023
b41b899
merge main
zhaochen20 Aug 8, 2023
5a4fdb0
Merge branch 'stress_original_examples' into add_new_member_variables
zhaochen20 Aug 8, 2023
f46ab8f
update docstring
Aug 8, 2023
f240ce4
fix review from graham
zhaochen20 Aug 8, 2023
aad558a
fix review from graham
zhaochen20 Aug 8, 2023
b35a3d7
Update tests/dataset_generator_with_filter_test.py
Aug 8, 2023
439a617
Merge branch 'add_new_member_variables' into input_output_map
zhaochen20 Aug 8, 2023
e18c042
use none stateful function
zhaochen20 Aug 8, 2023
a8b7d20
Merge branch 'input_output_map' of github.com:neulab/prompt2model int…
zhaochen20 Aug 8, 2023
2857b73
add new member variables (#198)
Aug 8, 2023
f25c905
use none stateful function for extract responses
zhaochen20 Aug 8, 2023
ecf8c4d
add new comments
zhaochen20 Aug 8, 2023
7e1d8d7
fix grammar
zhaochen20 Aug 8, 2023
98c6580
Update test_helpers/dataset_tools.py
Aug 8, 2023
8ad759c
Update test_helpers/dataset_tools.py
Aug 8, 2023
4980938
Merge branch 'input_output_map' of github.com:neulab/prompt2model int…
zhaochen20 Aug 8, 2023
97a693d
fix lint
zhaochen20 Aug 8, 2023
1c023eb
fix wrong comments
zhaochen20 Aug 8, 2023
d4612fb
Update prompt2model/dataset_generator/openai_gpt.py
Aug 8, 2023
df1e8cd
Merge branch 'input_output_map' of github.com:neulab/prompt2model int…
zhaochen20 Aug 8, 2023
315d198
fix lint
zhaochen20 Aug 8, 2023
bc5c993
[fetch]
zhaochen20 Aug 10, 2023
ced20c6
fix grammar error
zhaochen20 Aug 10, 2023
89901c9
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 11, 2023
cf44387
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 11, 2023
2a7328e
[fetch]
zhaochen20 Aug 11, 2023
7dea621
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 11, 2023
4c0594b
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 11, 2023
2b924b4
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 11, 2023
3aef8ac
fix Graham's review
zhaochen20 Aug 11, 2023
2226b97
fix conflict
zhaochen20 Aug 11, 2023
118b0d9
fix conflict
zhaochen20 Aug 11, 2023
86175a3
fix conflict
zhaochen20 Aug 11, 2023
6172ded
Merge branch 'main' of github.com:neulab/prompt2model
zhaochen20 Aug 11, 2023
a1462ce
fix conflict
zhaochen20 Aug 11, 2023
225 changes: 170 additions & 55 deletions prompt2model/dataset_generator/openai_gpt.py
@@ -95,33 +95,12 @@ def __init__(
self.requests_per_minute = requests_per_minute
self.filter_duplicated_examples = filter_duplicated_examples
self.cache_root = Path(cache_root)
# This list stores all generated examples. These will later be
# converted into `generated_dataset` and `input_output_map`
# if `filter_duplicated_examples` is True.
self.generated_examples: list[Example] = []

# `generated_examples` will be transformed into `generated_dataset`.
# If `filter_duplicated_examples` is True, `generated_examples` will
# be filtered based on multi-votes before being used to construct
# `generated_dataset`. If it's False, `generated_examples` will be
# used directly to construct `generated_dataset`.
self.generated_dataset: Dataset = Dataset.from_dict({})

# If `filter_duplicated_examples` is True, `self.generated_examples`
# will first be converted into `input_output_map`, and then into
# `generated_dataset`. If it's False, `input_output_map` will remain
# empty.
self.input_output_map: dict[str, Counter] = defaultdict(Counter)

# `generating_split` refers to the DatasetSplit currently being
# generated. After each loop, `generated_examples` will be
# stored as a Dataset at the path `{cache_root}/{generating_split}`.
self.generating_split: DatasetSplit | None = None

def generate_prompt(
self,
instruction: str,
few_shot_example_string: str = None,
few_shot_example_string: str,
generated_examples: list[Example],
) -> str:
"""Generates a prompt string.

@@ -130,43 +109,44 @@ def generate_prompt(
few_shot_example_string: A string representing the few-shot examples
parsed from the user's prompt, whose quality is higher than the
generated examples.
generated_examples: A list of currently generated examples.

Returns:
The generated prompt string.
"""
# The random_example_string is a string, which contains several random
# few-shot examples as demonstrations for the DatasetGenerator. If
# self.generated_examples is empty, then the random_example_string
# generated_examples is empty, then the random_example_string
# is the few-shot examples parsed from the user's prompt.
while True:
if len(self.generated_examples) == 0:
if len(generated_examples) == 0:
low_quality_example_string = "N/A\n"
# Create default low_quality_example_string if self.generated_examples
# Create default low_quality_example_string if generated_examples
# is empty. few_shot_example_string is the high-quality few-shot
examples parsed from the user's prompt. But if the user does not
provide any examples in the input prompt, the few_shot_example_string
# will be "N/A"/""/None.
random_selected_generated_example_num = 0
# random_selected_generated_example_num is the number of selected
# random examples from self.generated_examples that will be used to
# random examples from generated_examples that will be used to
# create the low_quality_example_string. If generated_examples
# is empty, then random_selected_generated_example_num is 0.
else:
# If self.generated_examples is not empty, low_quality_example_string
# is sveral random generated examples from self.generated_examples.
# If generated_examples is not empty, low_quality_example_string
# is several random generated examples from generated_examples.

low_quality_example_string = ""
random_selected_generated_example_num = random.randint(
1, min(len(self.generated_examples), 10)
1, min(len(generated_examples), 10)
)
# random_selected_generated_example_num is the number of selected
# random examples from self.generated_examples that will
# random examples from generated_examples that will
# be concatenated to create low_quality_example_string.
random_examples = random.sample(
self.generated_examples, random_selected_generated_example_num
generated_examples, random_selected_generated_example_num
)
# If generated_examples is not empty, then choose several
# random examples from self.generated_examples to construct
# random examples from generated_examples to construct
# new low_quality_example_string.
for example in random_examples:
low_quality_example_string += (
@@ -194,21 +174,161 @@
else:
continue

def extract_responses(self, completions: list[openai.Completion]) -> None:
def construct_input_output_map(
self,
generated_examples: list[Example],
) -> dict[str, Counter]:
"""Constructs a dictionary mapping inputs to `Counter` objects of outputs.

Args:
generated_examples: A list of currently generated examples.

Ideally, each input should have a unique output (one-to-one mapping).
However, language models may occasionally generate different outputs
for identical inputs. For instance, given the input “What is the biggest
city in China?”, it might produce different but correct outputs such as
“Shanghai” and “The biggest city in China is Shanghai”. At other times,
it may produce incorrect variations. For the input “What is the chemical
symbol of gold?”, the outputs might be “Au”, “Au”, and “AU”, where the
last one is wrong due to capitalization.

To address this, OpenAIDataSetGenerator uses a two-step multi-vote
filtering mechanism. This function represents the first step, creating a
dictionary to map inputs to a `Counter` of their outputs.

The function iterates over all the examples, building a dictionary where
inputs serve as keys and `Counter` objects as values. The `Counter`
tracks the frequency of each output for a specific input.

For example:
input: ["apple", "banana", "apple", "orange", "apple"]
output: ["A", "B", "A", "O", "D"]

Then the input_output_map is:
{
"apple": Counter({"A": 2, "D": 1}),
"banana": Counter({"B": 1}),
"orange": Counter({"O": 1})
}
"""
input_output_map: dict[str, Counter] = defaultdict(Counter)

# Iterate through the examples and construct the mapping.
for example in generated_examples:
input_str = example.input_col
output_str = example.output_col

# Increment the count of the output for the specific input.
input_output_map[input_str][output_str] += 1

# Ensure that the generated_examples list is not empty
# and the map is constructed correctly.
if len(generated_examples) != 0:
assert input_output_map

return input_output_map
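
Pulled out of the class, the mapping step above boils down to a few lines. The following is a minimal standalone sketch, not the PR's exact code: the `Example` namedtuple here is a stand-in for prompt2model's own `Example` type.

```python
from collections import Counter, defaultdict, namedtuple

# Hypothetical stand-in for prompt2model's Example type.
Example = namedtuple("Example", ["input_col", "output_col"])

def construct_input_output_map(generated_examples):
    """Map each input string to a Counter of its generated outputs."""
    input_output_map = defaultdict(Counter)
    for example in generated_examples:
        # Increment the count of this output for this input.
        input_output_map[example.input_col][example.output_col] += 1
    return input_output_map

examples = [
    Example("apple", "A"),
    Example("banana", "B"),
    Example("apple", "A"),
    Example("orange", "O"),
    Example("apple", "D"),
]
mapping = construct_input_output_map(examples)
# mapping["apple"] == Counter({"A": 2, "D": 1})
```

Because `defaultdict(Counter)` creates a fresh `Counter` on first access, no key-existence checks are needed in the loop.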

def apply_multi_vote_to_construct_generated_dataset(
self, input_output_map: dict[str, Counter]
) -> Dataset:
"""Multi-vote to construct generated_dataset from input_output_map.

Args:
input_output_map: A dictionary mapping inputs to `Counter` objects
of their outputs.

This method uses multi-vote filtering to create a unique mapping from inputs
to outputs. The input_col of generated_dataset contains unique inputs,
while the output_col holds the shortest, most frequent output for the
corresponding input.

The function asserts that self.filter_duplicated_examples is True and that
input_output_map is not None when generated_examples is not
empty. It then iterates over input_output_map, finding the most frequent
output for each input. If there are multiple outputs with the highest frequency,
it selects the shortest one. If there are multiple shortest outputs with the
highest frequency, it selects the one that comes first in lexicographical
(alphabetical) order.

Example:
Suppose input_output_map is:
{
"apple": Counter({"A": 2, "D": 2}),
"banana": Counter({"B": 2, "C": 1}),
"orange": Counter({"O": 1})
}

The function will produce generated_dataset:
{
"input_col": ["apple", "banana", "orange"],
"output_col": ["A", "B", "O"]
}

Note: When generated_examples is empty, both input_output_map
and generated_dataset will be empty.

Returns:
Currently generated dataset with multi-vote filtering applied.
"""
# Ensure that multi-vote filtering is enabled.
assert self.filter_duplicated_examples

filtered_inputs = []
filtered_outputs = []

for input_str, output_counter in input_output_map.items():
# Find the most frequent output count.
most_common_count = output_counter.most_common(1)[0][1]

# Get all the outputs that have the most common count.
most_frequent_outputs = [
output
for output, count in output_counter.items()
if count == most_common_count
]

# Sort the outputs based on their lengths and select
# the shortest ones. When several outputs have the
# same length with the highest frequency, they will
# be sorted in their lexicographical (alphabetical) order.
most_frequent_outputs.sort(key=len)
final_output = most_frequent_outputs[0]

filtered_inputs.append(input_str)
filtered_outputs.append(final_output)

# Note that when `generated_examples` is empty,
# `input_output_map` is None, and `generated_dataset`
# will also be empty.
generated_dataset = Dataset.from_dict(
{"input_col": filtered_inputs, "output_col": filtered_outputs}
)

# `generated_examples` will be transformed into `generated_dataset`.
# If `filter_duplicated_examples` is True, `generated_examples` will
# be filtered based on multi-votes before being used to construct
# `generated_dataset`. If it's False, `generated_examples` will be
# used directly to construct `generated_dataset`.
return generated_dataset
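
The tie-breaking rule described above (most frequent output, then shortest, then first in alphabetical order) can be exercised independently of the `Dataset` machinery. A hedged sketch — the function name `multi_vote` is chosen here for illustration and is not from the PR:

```python
from collections import Counter

def multi_vote(input_output_map):
    """Pick, per input, the most frequent output; break ties by
    shortest length, then by lexicographical order."""
    filtered = {}
    for input_str, counter in input_output_map.items():
        top_count = counter.most_common(1)[0][1]
        candidates = [o for o, c in counter.items() if c == top_count]
        # Python's sort is stable: sorting lexicographically first,
        # then taking the min by length, yields the shortest output,
        # with alphabetical order breaking length ties.
        filtered[input_str] = min(sorted(candidates), key=len)
    return filtered

votes = {
    "apple": Counter({"A": 2, "D": 2}),
    "banana": Counter({"B": 2, "C": 1}),
}
# multi_vote(votes) == {"apple": "A", "banana": "B"}
```

For the Shanghai example from the docstring, `Counter({"Shanghai": 1, "The biggest city in China is Shanghai": 1})` resolves to `"Shanghai"`: both outputs tie on frequency, so the shorter one wins.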

def extract_responses(
self, completions: list[openai.Completion], generated_examples: list[Example]
) -> list[Example]:
"""Extracts the generated sample and annotation from an OpenAI API response.

Args:
completions: The generated completion objects returned by OpenAI API.
generated_examples: Currently generated examples of DatasetGenerator.

Returns:
A list of `Example` objects.
Each API call will return `responses_per_request` completion objects.
If the response is a valid JSON object, create a namedtuple called
`example` and append it to self.generated_examples. `example` consists
`example` and append it to generated_examples. `example` consists
of `input_col` and`output_col`, where:
- input_col is the generated example string extracted from the response.
- output_col is the generated label string extracted from the response.
If the response is not a valid JSON object, discard it.
There is 5 * len(completions) responses at a time.
There are responses_per_request * len(completions) responses at a time.
"""
for completion in completions:
try:
@@ -230,14 +350,15 @@ def extract_responses(self, completions: list[openai.Completion]) -> None:
continue
input = str(response_json["input"]).strip()
output = str(response_json["output"]).strip()
self.generated_examples.append(Example(input, output))
generated_examples.append(Example(input, output))
logging.info(f"input: \n\n{input}\n\n")
logging.info(f"output: \n\n{output}\n\n")
except Exception:
logging.warning(
f"Error happened when parsing API completion: {completion}"
)
continue
return generated_examples
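
The parse-or-discard logic above is easy to exercise without real `openai.Completion` objects. A sketch over raw content strings — the helper name and sample inputs are illustrative, not from the PR:

```python
import json

def extract_valid_examples(raw_contents):
    """Keep only contents that parse as a JSON object with 'input'
    and 'output' fields; silently discard everything else."""
    examples = []
    for content in raw_contents:
        try:
            response_json = json.loads(content)
            examples.append(
                (str(response_json["input"]).strip(),
                 str(response_json["output"]).strip())
            )
        except (json.JSONDecodeError, KeyError, TypeError):
            # Invalid JSON, a non-object, or missing fields: discard.
            continue
    return examples

contents = [
    '{"input": "What is 2+2?", "output": "4"}',
    'not json at all',
    '{"input": "capital of France", "output": "Paris"}',
]
# Only the first and third entries survive.
```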

async def generate_responses(
self, chat_api: ChatGPTAgent, prompts: list[str]
@@ -280,9 +401,9 @@ def generate_dataset_split(
"""
_ = split # suppress unused variable warnings
chat_api = ChatGPTAgent(self.api_key)
self.generated_examples = []
generated_examples: list[Example] = []
pbar = tqdm(total=expected_num_examples, desc="Generating examples")
while len(self.generated_examples) < expected_num_examples:
while len(generated_examples) < expected_num_examples:
try:
if self.max_api_calls and self.api_call_counter >= self.max_api_calls:
logging.warning("Maximum number of API calls reached.")
self.batch_size,
math.ceil(
(
(
expected_num_examples
- len(self.generated_examples)
)
(expected_num_examples - len(generated_examples))
/ self.responses_per_request
)
),
self.batch_size,
math.ceil(
(
(
expected_num_examples
- len(self.generated_examples)
)
(expected_num_examples - len(generated_examples))
/ self.responses_per_request
)
),
self.generate_prompt(
instruction=prompt_spec.instruction,
few_shot_example_string=prompt_spec.examples,
generated_examples=generated_examples,
)
for _ in range(batch_size)
]
responses = loop.run_until_complete(
self.generate_responses(chat_api, prompts)
)
self.extract_responses(responses)
pbar.update(len(self.generated_examples) - pbar.n)
generated_examples = self.extract_responses(
responses, generated_examples=generated_examples
)
pbar.update(len(generated_examples) - pbar.n)
except OPENAI_ERRORS as e:
self.api_call_counter = handle_openai_error(e, self.api_call_counter)
# Each API call will return `responses_per_request` completion
# objects. The upper bound of the length of generated dataset
# is expected_num_examples + responses_per_request.
assert (
len(self.generated_examples)
< expected_num_examples + self.responses_per_request
len(generated_examples) < expected_num_examples + self.responses_per_request
)
return Dataset.from_dict(
{
"input_col": [example.input_col for example in self.generated_examples],
"output_col": [
example.output_col for example in self.generated_examples
],
"input_col": [example.input_col for example in generated_examples],
"output_col": [example.output_col for example in generated_examples],
}
)
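
The batch-size arithmetic in `generate_dataset_split` deserves a worked example: each API request yields `responses_per_request` completions, so each loop iteration asks for just enough requests to cover the remaining examples, capped at the configured batch size. A standalone sketch (the function name is illustrative):

```python
import math

def next_batch_size(max_batch_size, expected_num_examples,
                    num_generated, responses_per_request):
    """Number of prompts for the next API round: enough requests to
    cover the remaining examples, capped at the batch size."""
    remaining = expected_num_examples - num_generated
    return min(max_batch_size,
               math.ceil(remaining / responses_per_request))

# With 100 examples expected, 37 already generated, and 5 responses
# per request: ceil(63 / 5) = 13 requests, under a cap of 20 -> 13.
```

This is also why the post-loop assertion bounds the dataset at `expected_num_examples + responses_per_request`: the final round can overshoot by at most one request's worth of completions.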
10 changes: 8 additions & 2 deletions test_helpers/__init__.py
@@ -1,7 +1,11 @@
"""Import mock classes used in unit tests."""
from test_helpers.dataset_tools import (
are_dataset_dicts_identical,
are_datasets_identical,
)
from test_helpers.mock_openai import (
MockCompletion,
mock_batch_openai_response,
mock_batch_openai_response_with_identical_completions,
mock_one_openai_response,
)
from test_helpers.model_and_tokenizer import (
"create_gpt2_model_and_tokenizer",
"create_t5_model_and_tokenizer",
"mock_one_openai_response",
"mock_batch_openai_response",
"mock_batch_openai_response_with_identical_completions",
"are_dataset_dicts_identical",
"are_datasets_identical",
)