
Commit 28b0ec9

upload initial code
This is the initial code of the RAFT implementation.
1 parent fb8373b commit 28b0ec9

14 files changed: +706 -0 lines changed

.DS_Store (8 KB, binary file not shown)

annotate_data/.DS_Store (6 KB, binary file not shown)

annotate_data/get_rewards.py (+166 lines)
import json
import os
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline
from accelerate import Accelerator

tqdm.pandas()

#####
# This script takes a dataset as input, where each sample is {"prompt": "the prompt", "responses": ["response1", "response2", "response3", ...]}
# It computes the reward for each prompt-response pair and outputs a new dataset, where each sample is
# {"prompt": "the prompt", "responses": ["response1", "response2", ...], "rewards": [reward1, reward2, ...]}
#####


@dataclass
class ScriptArguments:
    """
    The arguments for the reward annotation script.
    """

    dataset_name_or_path: Optional[str] = field(
        default="uf_split0_responses_K8.jsonl",
        metadata={"help": "the location of the dataset name or path"},
    )
    output_dir: Optional[str] = field(
        default="uf_split0_responses_K8_reward.json",
        metadata={"help": "the location of the output file"},
    )
    record_dir: Optional[str] = field(
        default=None,
        metadata={"help": "the location of the recording file"},
    )
    reward_name_or_path: Optional[str] = field(
        default="sfairXC/FsfairX-LLaMA3-RM-v0.1",
        metadata={"help": "the name of the reward model"},
    )
    input_output_delimiter: Optional[str] = field(
        default="",
        metadata={"help": "the delimiter between input and output"},
    )
    K: Optional[int] = field(
        default=8,
        metadata={"help": "the number of responses per prompt"},
    )


accelerator = Accelerator()

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

device = accelerator.device
pipe_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": 1,
}
reward_model = script_args.reward_name_or_path
rm_tokenizer = AutoTokenizer.from_pretrained(reward_model)
rm_pipe = pipeline(
    "sentiment-analysis",
    model=reward_model,
    device=device,
    tokenizer=rm_tokenizer,
    model_kwargs={"torch_dtype": torch.bfloat16},
    truncation=True,
)


ds_dir = script_args.dataset_name_or_path
world_size = int(os.getenv("WORLD_SIZE", "1"))
ds = load_dataset("json", data_files=ds_dir, split="train")

local_rank = Accelerator().local_process_index

data_size = len(ds["prompt"])

# Each process scores a contiguous shard of the dataset
share = int(data_size / world_size) + 1
ds = ds.select(np.arange(local_rank * share, min((local_rank + 1) * share, len(ds))))

"""
We process the data format here and query the reward model to get the rewards.
"""


def get_reward(test_texts):
    pipe_outputs = rm_pipe(test_texts, **pipe_kwargs)
    rewards = [output[0]["score"] for output in pipe_outputs]
    return rewards


def change_of_format(prom, resp):
    # To be modified according to the reward model and the LLM you use
    # Be careful about multi-turn conversations
    """
    prom = prom.replace("<s>GPT4 Correct User: ", "").replace("<|end_of_turn|>GPT4 Correct Assistant:", "")

    final_resp = resp.split("GPT4 Correct User")[0]
    """
    message = prom + [{"role": "assistant", "content": resp}]
    return rm_tokenizer.apply_chat_template(message, tokenize=False).replace(rm_tokenizer.bos_token, "")


data = []

# tqdm is used to show the progress bar
with torch.no_grad():
    for sample in tqdm(ds):
        # vLLM may not generate responses for some prompts because they are too long; we skip them
        if len(sample["responses"]) < script_args.K:
            continue
        test_texts = [change_of_format(sample["prompt"], tmp_output) for tmp_output in sample["responses"]]

        rewards = get_reward(test_texts)
        data.append({"prompt": sample["prompt"], "responses": sample["responses"], "rewards": rewards})


# Gather the scored shards from all processes
world_size = int(os.getenv("WORLD_SIZE", "1"))
all_process_list = [{}] * world_size

data_to_send = {
    "data": [[data[i]] for i in range(len(data))],
}

import torch.distributed as dist

dist.all_gather_object(all_process_list, data_to_send)
gathered_data = []


for i in range(world_size):
    tmp_data = [tmp[0] for tmp in all_process_list[i]["data"]]
    gathered_data.extend(tmp_data)

all_rewards = [sample["rewards"] for sample in gathered_data]
top1_scores = np.mean(np.max(all_rewards, axis=1))
mean_scores = np.mean(all_rewards)


if local_rank == 0:
    print(
        "Collect {} data from {} inputs. mean score {} top1 score: {}".format(
            len(gathered_data), data_size, mean_scores, top1_scores
        )
    )
    if len(gathered_data) < data_size:
        print(
            "Some of the prompts have fewer than {} responses. This can happen when a prompt is too long and is skipped by vLLM".format(
                script_args.K
            )
        )

with open(script_args.output_dir, "w", encoding="utf8") as f:
    for i in range(len(gathered_data)):
        json.dump(gathered_data[i], f, ensure_ascii=False)
        f.write("\n")

if script_args.record_dir is not None:
    with open(script_args.record_dir, "a") as f:
        f.write(str(mean_scores) + "\t" + str(top1_scores) + "\n")
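
RAFT-style training typically keeps the highest-reward response per prompt for supervised fine-tuning. As a rough illustration of how the annotated JSONL written above could be consumed for that selection step (the output path is just the script's default; this snippet is not part of the commit):

import json

import numpy as np

# Read the reward-annotated JSONL produced by get_rewards.py (default output path)
best_of_k = []
with open("uf_split0_responses_K8_reward.json", encoding="utf8") as f:
    for line in f:
        sample = json.loads(line)
        # Keep the response with the highest reward for each prompt
        best_idx = int(np.argmax(sample["rewards"]))
        best_of_k.append({"prompt": sample["prompt"], "response": sample["responses"][best_idx]})

print("Selected", len(best_of_k), "best-of-K pairs")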

configs/.DS_Store (6 KB, binary file not shown)

configs/deepspeed_stage1.json (+23 lines)
{
    "zero_optimization": {
        "stage": 1,
        "overlap_comm": true
    },
    "bf16": {
        "enabled": "auto"
    },
    "fp16": {
        "enabled": "auto",
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 32,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

configs/deepspeed_stage2.json (+27 lines)
{
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        },
        "contiguous_gradients": true,
        "overlap_comm": true
    },
    "bf16": {
        "enabled": "auto"
    },
    "fp16": {
        "enabled": "auto",
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 32,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

configs/deepspeed_stage3.json (+31 lines)
{
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 0,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 0,
        "stage3_max_reuse_distance": 0,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "bf16": {
        "enabled": true
    },
    "fp16": {
        "enabled": "auto",
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 32,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
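
The three configs differ mainly in the ZeRO stage (stage 2 additionally offloads optimizer states to CPU); batch sizes, gradient accumulation, and clipping are left as "auto" so the training launcher fills them in. As a hedged sketch of how such a config is typically wired into a Hugging Face Trainer run (the training script is not part of this commit, and the values below are illustrative):

from transformers import TrainingArguments

# Illustrative values only; the "auto" fields in the DeepSpeed JSON are resolved from these arguments
training_args = TrainingArguments(
    output_dir="./raft_sft_checkpoint",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    bf16=True,
    deepspeed="configs/deepspeed_stage2.json",  # any of the three configs can be dropped in here
)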

generation/.DS_Store (6 KB, binary file not shown)

generation/gen_hf.py (+142 lines)
import json
from dataclasses import dataclass, field
from typing import List, Optional
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests

tqdm.pandas()


@dataclass
class ScriptArguments:
    """
    The arguments for the generation script.
    """

    url: Optional[str] = field(
        default="http://localhost",
        metadata={"help": "url of the inference server"},
    )
    tokenizer: Optional[str] = field(
        default="HuggingFaceH4/mistral-7b-sft-beta",
        metadata={"help": "the tokenizer to use"},
    )
    ports: List[str] = field(default_factory=lambda: ["8000"], metadata={"help": "ports of the inference servers"})
    eos_ids: List[int] = field(default_factory=lambda: [], metadata={"help": "the ids of the end of sentence tokens"})
    dataset_name_or_path: Optional[str] = field(
        default="cornfieldrm/iterative-prompt-v1-iter1-2K",
        metadata={"help": "the location of the dataset name or path"},
    )
    output_dir: Optional[str] = field(
        default="uf_split0_responses_K8.jsonl",
        metadata={"help": "the location of the output file"},
    )
    bos_format: Optional[str] = field(
        default="",
        metadata={"help": "the format of the beginning of the sentence"},
    )
    K: Optional[int] = field(
        default=8,
        metadata={"help": "the number of generations per prompt"},
    )
    max_input_length: Optional[int] = field(
        default=10000,
        metadata={"help": "the maximum length of the input tokens"},
    )
    max_new_tokens: Optional[int] = field(
        default=2048,
        metadata={"help": "the maximum length of the new tokens"},
    )
    seed: Optional[int] = field(
        default=42,
        metadata={"help": "the random seed"},
    )
    temperature: Optional[float] = field(
        default=0.7,
        metadata={"help": "the sampling temperature"},
    )
    use_beam_search: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to use beam search"},
    )
    dataset_key: Optional[str] = field(
        default="context_messages",
        metadata={"help": "the key of the dataset"},
    )
    max_workers: Optional[int] = field(
        default=1024,
        metadata={"help": "the number of workers"},
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
ds_dir = script_args.dataset_name_or_path
output_dir = script_args.output_dir
K = script_args.K
ports = script_args.ports

tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer)


def query_model(prompt, args, port):
    json = {
        **args,
        "prompt": prompt,
    }
    response = requests.post(url=script_args.url + ":" + str(port) + "/generate", json=json)
    response_json = response.json()
    # The server returns the prompt concatenated with each completion, so strip the prompt prefix
    return [response_json["text"][i][len(prompt):] for i in range(len(response_json["text"]))]


default_args = {
    "use_beam_search": script_args.use_beam_search,
    "n": script_args.K,
    "temperature": script_args.temperature,
    "max_tokens": script_args.max_new_tokens,
    "seed": script_args.seed,
    "top_p": 1.0,
    "top_k": -1,
    "stop_token_ids": [tokenizer.eos_token_id] + script_args.eos_ids,
}

print(default_args)

ds = load_dataset(ds_dir, split="train")
# load_dataset("json", data_files=ds_dir, split="train", field="instances")
print(ds)

# use tokenizer.apply_chat_template to apply the chat template to the prompt
ds = ds.map(
    lambda x: {
        "prompt": tokenizer.apply_chat_template(x[script_args.dataset_key], tokenize=False, add_generation_prompt=True)
    }
)


with ThreadPoolExecutor(max_workers=script_args.max_workers) as executor:
    result = [
        executor.submit(query_model, ds[i]["prompt"], default_args, ports[i % len(ports)]) for i in range(len(ds))
    ]
    # use tqdm to show progress
    for _ in tqdm(as_completed(result), total=len(result)):
        pass

responses = [r.result() for r in result]


gathered_data = []
for i in range(len(ds)):
    tmp_data = {"prompt": ds[i][script_args.dataset_key], "responses": responses[i]}
    gathered_data.append(tmp_data)

print("Collected", len(gathered_data), "samples")


with open(output_dir, 'w', encoding='utf8') as f:
    for i in range(len(gathered_data)):
        json.dump(gathered_data[i], f, ensure_ascii=False)
        f.write('\n')
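
query_model assumes an already-running vLLM-style HTTP server on each port whose /generate endpoint returns the prompt concatenated with each of the n sampled completions in response_json["text"]; that is why the prompt prefix is sliced off. A small mocked illustration of that slicing (the server reply below is made up, not a real response):

# Mocked server reply, in the shape query_model expects
prompt = "What is RAFT?"
response_json = {"text": [prompt + " RAFT is ...", prompt + " It stands for ..."]}

# Same slicing as in query_model: keep only the newly generated text
completions = [t[len(prompt):] for t in response_json["text"]]
print(completions)  # [' RAFT is ...', ' It stands for ...']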
