Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions recipes/Qwen2.5-7B-Instruct/beam_search.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Beam-search test-time-compute recipe for Qwen2.5-7B-Instruct.
# refer to src/sal/config.py for more options

model_path: Qwen/Qwen2.5-7B-Instruct  # generator model
prm_path: HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2  # process reward model used to score steps
custom_chat_template: null  # null -> use the tokenizer's built-in chat template
filter_duplicates: true  # drop duplicate completions before scoring
approach: beam_search
n: 4  # number of beams / candidates
search_batch_size: 1 # DO NOT CHANGE!
num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
seed: 0  # RNG seed for reproducibility
12 changes: 12 additions & 0 deletions recipes/Qwen2.5-7B-Instruct/best_of_n.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Best-of-n test-time-compute recipe for Qwen2.5-7B-Instruct.
# refer to src/sal/config.py for more options

model_path: Qwen/Qwen2.5-7B-Instruct  # generator model
prm_path: HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2  # process reward model used to rank candidates
custom_chat_template: null  # null -> use the tokenizer's built-in chat template
approach: best_of_n
n: 4  # number of candidate completions per problem
search_batch_size: 25  # problems scored per batch
sort_completed: true  # rank completed candidates by PRM score
filter_duplicates: true  # drop duplicate completions before scoring
num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
seed: 0  # RNG seed for reproducibility
10 changes: 10 additions & 0 deletions recipes/Qwen2.5-7B-Instruct/dvts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Diverse-verifier-tree-search (DVTS) recipe for Qwen2.5-7B-Instruct.
# refer to src/sal/config.py for more options

model_path: Qwen/Qwen2.5-7B-Instruct  # generator model
prm_path: HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2  # process reward model used to score steps
custom_chat_template: null  # null -> use the tokenizer's built-in chat template
approach: dvts
n: 4  # number of candidates / subtree width
search_batch_size: 25  # problems scored per batch
num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
seed: 0  # RNG seed for reproducibility
62 changes: 61 additions & 1 deletion src/sal/models/reward_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
AutoModelForTokenClassification,
)

from sal.config import Config
Expand Down Expand Up @@ -399,6 +400,55 @@ def make_step_rewards(logits, token_masks):

return all_scores_res

class H4PRM(PRM):
    """Process reward model backed by a HuggingFace token-classification head.

    Each reasoning step of a candidate answer is scored by running the
    (question + steps-so-far) prefix through the classifier and reading the
    prediction for the final token of that prefix.
    """

    @classmethod
    def _load_model_and_tokenizer(
        cls, prm_model_path, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the PRM classifier and its tokenizer in eval mode (fp16, auto device map)."""
        tokenizer = AutoTokenizer.from_pretrained(prm_model_path)
        tokenizer.padding_side = "left"
        model = AutoModelForTokenClassification.from_pretrained(
            prm_model_path,
            device_map="auto",
            torch_dtype=torch.float16,
            **model_kwargs,
        ).eval()
        return model, tokenizer

    def score(
        self, questions: list[str], outputs: list[list[str]]
    ) -> list[list[float]]:
        """Score every step of every candidate answer.

        Args:
            questions: one question per problem.
            outputs: for each question, a list of candidate answers whose
                steps are delimited by double newlines.

        Returns:
            Per question, per answer, a list of per-step scores. Each score
            is the softmax probability of the positive class ("LABEL_1"),
            a continuous value in [0, 1]. (A hard argmax 0/1 score would
            tie most candidates, degrading best-of-n / beam-search ranking.)
        """
        separator = "\n\n"
        # Resolve the positive-class index once, outside the loops.
        # assumes the checkpoint uses the default "LABEL_0"/"LABEL_1" names;
        # fall back to index 1 otherwise — TODO confirm against model config.
        label2id = {label: idx for idx, label in self.model.config.id2label.items()}
        positive_id = label2id.get("LABEL_1", 1)
        all_scores = []
        for question, answers in zip(questions, outputs):
            answer_scores = []
            for answer in answers:
                steps = answer.split(separator)
                step_scores = []
                for idx in range(1, len(steps) + 1):
                    # Re-encode the growing prefix: question + first idx steps.
                    text = separator.join([question] + steps[:idx]) + separator
                    inputs = self.tokenizer(
                        text,
                        return_tensors="pt",
                        truncation=True,
                    ).to(self.model.device)

                    with torch.no_grad():
                        # Avoid rebinding `outputs` (the method parameter).
                        logits = self.model(**inputs).logits

                    # Classify the last token of the prefix; upcast to fp32
                    # before softmax for numerical stability under fp16.
                    last_token_logits = logits[0, -1].float()
                    probs = torch.softmax(last_token_logits, dim=-1)
                    step_scores.append(probs[positive_id].item())
                answer_scores.append(step_scores)

            all_scores.append(answer_scores)

        return all_scores

class SkyworkO1_1_5B(SkyworkO1):
def load_model_and_tokenizer(
Expand All @@ -423,6 +473,13 @@ def load_model_and_tokenizer(
prm_model_path = "Qwen/Qwen2.5-Math-PRM-7B"
return Qwen_2_5_Math._load_model_and_tokenizer(prm_model_path, **model_kwargs)

class H4_Qwen_2_5_Math_7B(H4PRM):
    """H4PRM pinned to the HuggingFaceH4 Qwen2.5-Math-7B-Instruct PRM checkpoint."""

    def load_model_and_tokenizer(
        self, **model_kwargs
    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the fixed PRM checkpoint; extra kwargs are forwarded to from_pretrained."""
        return self._load_model_and_tokenizer(
            "HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2", **model_kwargs
        )


def load_prm(config: Config) -> PRM:
if config.prm_path == "peiyi9979/math-shepherd-mistral-7b-prm":
Expand All @@ -440,4 +497,7 @@ def load_prm(config: Config) -> PRM:
if config.prm_path == "Qwen/Qwen2.5-Math-PRM-7B":
return Qwen_2_5_Math_7B(config)

raise NotImplementedError(f"PRM {config.prm_path} not implemented")
if config.prm_path == "HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2":
return H4_Qwen_2_5_Math_7B(config)

raise NotImplementedError(f"PRM {config.prm_path} not implemented")