From 971c5f2738f07dd5c783ec1d91deb6caf08a30d7 Mon Sep 17 00:00:00 2001
From: Ritvik19
Date: Thu, 22 May 2025 12:18:20 +0530
Subject: [PATCH] feat: Add support for H4 Qwen2.5-Math-7B-Instruct-PRM and
 Qwen2.5-7B-Instruct recipes

---
 recipes/Qwen2.5-7B-Instruct/beam_search.yaml | 11 ++++
 recipes/Qwen2.5-7B-Instruct/best_of_n.yaml   | 12 ++++
 recipes/Qwen2.5-7B-Instruct/dvts.yaml        | 10 ++++
 src/sal/models/reward_models.py              | 62 +++++++++++++++++++-
 4 files changed, 94 insertions(+), 1 deletion(-)
 create mode 100644 recipes/Qwen2.5-7B-Instruct/beam_search.yaml
 create mode 100644 recipes/Qwen2.5-7B-Instruct/best_of_n.yaml
 create mode 100644 recipes/Qwen2.5-7B-Instruct/dvts.yaml

diff --git a/recipes/Qwen2.5-7B-Instruct/beam_search.yaml b/recipes/Qwen2.5-7B-Instruct/beam_search.yaml
new file mode 100644
index 00000000..10d57290
--- /dev/null
+++ b/recipes/Qwen2.5-7B-Instruct/beam_search.yaml
@@ -0,0 +1,11 @@
+# refer to src/sal/config.py for more options
+
+model_path: Qwen/Qwen2.5-7B-Instruct
+prm_path: HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2
+custom_chat_template: null
+filter_duplicates: true
+approach: beam_search
+n: 4
+search_batch_size: 1 # DO NOT CHANGE!
+num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
+seed: 0
diff --git a/recipes/Qwen2.5-7B-Instruct/best_of_n.yaml b/recipes/Qwen2.5-7B-Instruct/best_of_n.yaml
new file mode 100644
index 00000000..4836e392
--- /dev/null
+++ b/recipes/Qwen2.5-7B-Instruct/best_of_n.yaml
@@ -0,0 +1,12 @@
+# refer to src/sal/config.py for more options
+
+model_path: Qwen/Qwen2.5-7B-Instruct
+prm_path: HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2
+custom_chat_template: null
+approach: best_of_n
+n: 4
+search_batch_size: 25
+sort_completed: true
+filter_duplicates: true
+num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
+seed: 0
diff --git a/recipes/Qwen2.5-7B-Instruct/dvts.yaml b/recipes/Qwen2.5-7B-Instruct/dvts.yaml
new file mode 100644
index 00000000..260db0e8
--- /dev/null
+++ b/recipes/Qwen2.5-7B-Instruct/dvts.yaml
@@ -0,0 +1,10 @@
+# refer to src/sal/config.py for more options
+
+model_path: Qwen/Qwen2.5-7B-Instruct
+prm_path: HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2
+custom_chat_template: null
+approach: dvts
+n: 4
+search_batch_size: 25
+num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
+seed: 0
diff --git a/src/sal/models/reward_models.py b/src/sal/models/reward_models.py
index b78c1b15..92a98851 100644
--- a/src/sal/models/reward_models.py
+++ b/src/sal/models/reward_models.py
@@ -23,6 +23,7 @@
     AutoTokenizer,
     PreTrainedModel,
     PreTrainedTokenizer,
+    AutoModelForTokenClassification,
 )
 
 from sal.config import Config
@@ -399,6 +400,55 @@ def make_step_rewards(logits, token_masks):
         return all_scores_res
 
 
+class H4PRM(PRM):
+    @classmethod
+    def _load_model_and_tokenizer(
+        cls, prm_model_path, **model_kwargs
+    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
+        tokenizer = AutoTokenizer.from_pretrained(prm_model_path)
+        tokenizer.padding_side = "left"
+        model = AutoModelForTokenClassification.from_pretrained(
+            prm_model_path,
+            device_map="auto",
+            torch_dtype=torch.float16,
+            **model_kwargs,
+        ).eval()
+        return model, tokenizer
+
+    def score(
+        self, questions: list[str], outputs: list[list[str]]
+    ) -> list[list[float]]:
+        separator = "\n\n"
+        all_scores = []
+        for question, answers in zip(questions, outputs):
+            answer_scores = []
+            for answer in answers:
+                steps = answer.split(separator)
+                step_scores = []
+                for idx in range(1, len(steps) + 1):
+                    text = separator.join([question] + steps[:idx]) + separator
+                    inputs = self.tokenizer(
+                        text,
+                        return_tensors="pt",
+                        truncation=True,
+                        padding=True,
+                    ).to(self.model.device)
+
+                    with torch.no_grad():
+                        # NOTE: use a distinct name so the `outputs` parameter
+                        # (list of candidate answers) is not shadowed by the
+                        # model's forward-pass result.
+                        model_outputs = self.model(**inputs)
+                        logits = model_outputs.logits
+
+                    last_token_logits = logits[0, -1]
+                    predicted_label_id = torch.argmax(last_token_logits).item()
+                    predicted_label = self.model.config.id2label[predicted_label_id]
+
+                    score = 1.0 if predicted_label == "LABEL_1" else 0.0
+                    step_scores.append(score)
+                answer_scores.append(step_scores)
+
+            all_scores.append(answer_scores)
+
+        return all_scores
 
 class SkyworkO1_1_5B(SkyworkO1):
     def load_model_and_tokenizer(
@@ -423,6 +473,13 @@ def load_model_and_tokenizer(
         prm_model_path = "Qwen/Qwen2.5-Math-PRM-7B"
         return Qwen_2_5_Math._load_model_and_tokenizer(prm_model_path, **model_kwargs)
 
+class H4_Qwen_2_5_Math_7B(H4PRM):
+    def load_model_and_tokenizer(
+        self, **model_kwargs
+    ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
+        prm_model_path = "HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2"
+        return H4PRM._load_model_and_tokenizer(prm_model_path, **model_kwargs)
+
 
 def load_prm(config: Config) -> PRM:
     if config.prm_path == "peiyi9979/math-shepherd-mistral-7b-prm":
@@ -440,4 +497,7 @@ def load_prm(config: Config) -> PRM:
     if config.prm_path == "Qwen/Qwen2.5-Math-PRM-7B":
         return Qwen_2_5_Math_7B(config)
 
-    raise NotImplementedError(f"PRM {config.prm_path} not implemented")
+    if config.prm_path == "HuggingFaceH4/Qwen2.5-Math-7B-Instruct-PRM-0.2":
+        return H4_Qwen_2_5_Math_7B(config)
+
+    raise NotImplementedError(f"PRM {config.prm_path} not implemented")