From fef2f5055908f2f43a1befcd58f3b3a66fd000df Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Sun, 25 Aug 2024 20:35:56 +0000 Subject: [PATCH 01/14] fix --- src/rank_llm/rerank/pairwise/__init__.py | 0 src/rank_llm/rerank/pairwise/duot5.py | 208 ++++++++++++++++++ .../rerank/pairwise/pairwise_rankllm.py | 20 ++ 3 files changed, 228 insertions(+) create mode 100644 src/rank_llm/rerank/pairwise/__init__.py create mode 100644 src/rank_llm/rerank/pairwise/duot5.py create mode 100644 src/rank_llm/rerank/pairwise/pairwise_rankllm.py diff --git a/src/rank_llm/rerank/pairwise/__init__.py b/src/rank_llm/rerank/pairwise/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py new file mode 100644 index 00000000..8474f9b1 --- /dev/null +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -0,0 +1,208 @@ +import logging +import math +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import cmp_to_key +from typing import Dict, List, Optional, Tuple + +import torch +from tqdm import tqdm +from transformers import T5ForConditionalGeneration, T5Tokenizer +from transformers.generation import GenerationConfig + +from rank_llm.data import Candidate, Result +from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM + +logger = logging.getLogger(__name__) + + +class DuoT5(PairwiseRankLLM): + def __init__( + self, + model: str, + device: str = "cuda", + window_size: int = 20, + batched: bool = False, + ): + super.__init(model, device, window_size, batched) + self._tokenizer = T5Tokenizer.from_pretrained("castorini/duot5-base-msmarco") + self._llm = T5ForConditionalGeneration.from_pretrained( + "castorini/duot5-base-msmarco" + ).to(self._device) + + def run_llm_batched( + self, + prompts: List[str | List[Dict[str, str]]], + current_window_size: Optional[int] = None, + ) -> List[Tuple[str, int]]: + if SamplingParams is None: + raise ImportError( + "Please install rank-llm with `pip install rank-llm[vllm]` to use batch inference." 
+ ) + logger.info(f"VLLM Generating!") + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=self.num_output_tokens(current_window_size), + min_tokens=self.num_output_tokens(current_window_size), + ) + outputs = self._llm.generate(prompts, sampling_params) + + return [ + (output.outputs[0].text, len(output.outputs[0].token_ids)) + for output in outputs + ] + + def run_llm( + self, prompt: str, current_window_size: Optional[int] = None + ) -> Tuple[str, int, float]: + # CHANGE THIS CODE + if current_window_size is None: + current_window_size = self._window_size + inputs = self._tokenizer([prompt]) + inputs = {k: torch.tensor(v).to(self._device) for k, v in inputs.items()} + gen_cfg = GenerationConfig.from_model_config(self._llm.config) + gen_cfg.max_new_tokens = self.num_output_tokens() + gen_cfg.min_new_tokens = self.num_output_tokens() + gen_cfg.decoder_start_token_id = None + gen_cfg.output_scores = True + gen_cfg.return_dict_in_generate = True + # gen_cfg.temperature = 0 + gen_cfg.do_sample = False + token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to( + self._device + ) + output = self._llm.generate(token_prompt, generation_config=gen_cfg) + output_ids = output.sequences + logits = output.scores + + if self._llm.config.is_encoder_decoder: + output_ids = output_ids[0] + output_ids = output_ids[1:] + + outputs = self._tokenizer.decode( + output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + ) + truth_logit = logits[0][0][1176] + false_logit = logits[0][0][6136] + score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) + # print(outputs, output_ids.size(0)) + return outputs, output_ids.size(0), score + + def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: + return 1 + + def _add_prefix_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." + + def _add_post_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." 
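A minimal standalone sketch of the true/false scoring step in run_llm above. The token ids 1176 and 6136 mirror the hard-coded constants in the diff, which the code treats as T5's ids for the "true" and "false" tokens; looking the ids up through the tokenizer would be the more robust choice, and the helper name below is illustrative rather than part of the patch.

    import math

    TRUE_TOKEN_ID = 1176   # id the code above uses for the "true" token (assumed)
    FALSE_TOKEN_ID = 6136  # id the code above uses for the "false" token (assumed)

    def true_false_probability(first_token_logits) -> float:
        # Two-way softmax over the "true"/"false" logits of the first generated token:
        # interpreted as P(the first document is more relevant than the second).
        t = float(first_token_logits[TRUE_TOKEN_ID])
        f = float(first_token_logits[FALSE_TOKEN_ID])
        return math.exp(t) / (math.exp(t) + math.exp(f))

Algebraically this equals sigmoid(t - f), so the score depends only on the gap between the two logits, not on the rest of the vocabulary distribution.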
+ + def _add_few_shot_examples(self, conv): + return 1 + # unused for now + + def create_prompt( + self, result: Result, rank_start: int, rank_end: int + ) -> Tuple[str, int]: + # query = result.query.text + # query = self._replace_number(query) + # input = f"Query: {query} Document: {result.candidates[rank_start].doc['contents']}" + # prompt = self._tokenizer.decode(self._tokenizer.encode(input)[:480])[:-4] + " Relevant: " + # prompt = prompt.replace("","") + + # CHANGE THIS CODE + query = result.query.text + query = self._replace_number(query) + doc1 = result.candidates[rank_start].doc["contents"] + doc2 = result.candidates[rank_end].doc["contents"] + doc1 = self._tokenizer.decode(self._tokenizer.encode(doc1)[:240])[:-4] + doc2 = self._tokenizer.decode(self._tokenizer.encode(doc2)[:240])[:-4] + prompt = f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant:" + prompt = prompt.replace("", "") + return prompt, self.get_num_tokens(prompt) + + def create_prompt_batched( + self, + results: List[Result], + rank_start: int, + rank_end: int, + batch_size: int = 32, + ) -> List[Tuple[str, int]]: + def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + all_completed_prompts = [] + + with ThreadPoolExecutor() as executor: + for batch in tqdm(chunks(results, batch_size), desc="Processing batches"): + futures = [ + executor.submit(self.create_prompt, result, rank_start, rank_end) + for result in batch + ] + completed_prompts = [ + future.result() for future in as_completed(futures) + ] + all_completed_prompts.extend(completed_prompts) + return all_completed_prompts + + def get_num_tokens(self, prompt: str) -> int: + return len(self._tokenizer.encode(prompt)) + + def cost_per_1k_token(self, input_token: bool) -> float: + return 0 + + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: + if x.score < y.score: + return -1 + elif x.score > y.score: + return 1 + else: + return 0 + + def permutation_pipeline( + self, + result: Result, + rank_start: int, + rank_end: int, + logging: bool = False, + ) -> Result: + """ + Runs the permutation pipeline on the passed in result set within the passed in rank range. + + Args: + result (Result): The result object to process. + rank_start (int): The start index for ranking. + rank_end (int): The end index for ranking. + logging (bool, optional): Flag to enable logging of operations. Defaults to False. + + Returns: + Result: The processed result object after applying permutation. 
+ """ + # CHANGE THIS CODE + # print(len(result.candidates)) + # for i in range (len(result.candidates)): + # prompt, num_tokens = self.create_prompt(result, i, rank_end) + # output, output_num_tokens, score = self.run_llm(prompt=prompt) + # (result.candidates[i]).score = score + + # result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) + n = len(result.candidates) + scores = [0 for _ in range(n)] + for i in range(n): + for j in range(n): + if j == i: + continue + else: + prompt1, num_tokens1 = self.create_prompt(result, i, j) + prompt2, num_tokens2 = self.create_prompt(result, j, i) + _, _, pi_j = self.run_llm(prompt=prompt1) + _, _, pj_i = self.run_llm(prompt=prompt2) + scores[i] = scores[i] + pi_j + 1 - pj_i + + for i in range(n): + (result.candidates[i]).score = scores[i] + + result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) + + return result diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py new file mode 100644 index 00000000..90242c7e --- /dev/null +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -0,0 +1,20 @@ +import logging +from abc import ABC + +from rank_llm.rerank.rankllm import RankLLM + +logger = logging.getLogger(__name__) + + +class PairwiseRankLLM(RankLLM, ABC): + def __init__( + self, + model: str, + device: str = "cuda", + window_size: int = 20, + batched: bool = False, + ) -> None: + super.__init__(model) + self._window_size = window_size + self._device = device + self._batched = batched From 9cb96123f459a7f182d13fcfda2d33f332b880aa Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Sun, 25 Aug 2024 21:09:38 +0000 Subject: [PATCH 02/14] inits, cleanup --- .../rerank/listwise/listwise_rankllm.py | 2 + src/rank_llm/rerank/pairwise/__init__.py | 3 + src/rank_llm/rerank/pairwise/duot5.py | 72 ++++++++-------- .../rerank/pairwise/pairwise_rankllm.py | 37 ++++++--- src/rank_llm/rerank/pointwise/__init__.py | 4 +- .../rerank/pointwise/pointwise_rankllm.py | 25 ++++-- src/rank_llm/rerank/rankllm.py | 82 +++++++++++-------- 7 files changed, 132 insertions(+), 93 deletions(-) diff --git a/src/rank_llm/rerank/listwise/listwise_rankllm.py b/src/rank_llm/rerank/listwise/listwise_rankllm.py index 96495df8..09f0e0b0 100644 --- a/src/rank_llm/rerank/listwise/listwise_rankllm.py +++ b/src/rank_llm/rerank/listwise/listwise_rankllm.py @@ -17,6 +17,8 @@ class ListwiseRankLLM(RankLLM, ABC): """ + Abstract base class that all listwise rerankers inherit. 
+ All children of ListwiseRankLLM must implement these functions: - rerank_batched - run_llm_batched diff --git a/src/rank_llm/rerank/pairwise/__init__.py b/src/rank_llm/rerank/pairwise/__init__.py index e69de29b..99042941 100644 --- a/src/rank_llm/rerank/pairwise/__init__.py +++ b/src/rank_llm/rerank/pairwise/__init__.py @@ -0,0 +1,3 @@ +from .duot5 import DuoT5 + +__all__ = ["DuoT5"] diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index 8474f9b1..e77f14a9 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -2,15 +2,16 @@ import math from concurrent.futures import ThreadPoolExecutor, as_completed from functools import cmp_to_key -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from tqdm import tqdm from transformers import T5ForConditionalGeneration, T5Tokenizer from transformers.generation import GenerationConfig -from rank_llm.data import Candidate, Result +from rank_llm.data import Candidate, Request, Result from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM +from rank_llm.rerank.rankllm import PromptMode logger = logging.getLogger(__name__) @@ -19,37 +20,32 @@ class DuoT5(PairwiseRankLLM): def __init__( self, model: str, - device: str = "cuda", - window_size: int = 20, - batched: bool = False, + context_size: int, + prompt_mode: PromptMode, ): - super.__init(model, device, window_size, batched) + super.__init__(model, context_size, prompt_mode) self._tokenizer = T5Tokenizer.from_pretrained("castorini/duot5-base-msmarco") self._llm = T5ForConditionalGeneration.from_pretrained( "castorini/duot5-base-msmarco" ).to(self._device) - def run_llm_batched( + # TODO + def rerank_batch( self, - prompts: List[str | List[Dict[str, str]]], - current_window_size: Optional[int] = None, - ) -> List[Tuple[str, int]]: - if SamplingParams is None: - raise ImportError( - "Please install rank-llm with `pip install rank-llm[vllm]` to use batch inference." - ) - logger.info(f"VLLM Generating!") - sampling_params = SamplingParams( - temperature=0.0, - max_tokens=self.num_output_tokens(current_window_size), - min_tokens=self.num_output_tokens(current_window_size), - ) - outputs = self._llm.generate(prompts, sampling_params) + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: logging.Any, + ) -> List[Result]: + return - return [ - (output.outputs[0].text, len(output.outputs[0].token_ids)) - for output in outputs - ] + # TODO + def run_llm_batched( + self, prompts: List[str | List[torch.Dict[str, str]]], **kwargs + ) -> List[Tuple[str | int]]: + return def run_llm( self, prompt: str, current_window_size: Optional[int] = None @@ -87,19 +83,6 @@ def run_llm( # print(outputs, output_ids.size(0)) return outputs, output_ids.size(0), score - def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: - return 1 - - def _add_prefix_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." - - def _add_post_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." 
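For reference, the pairwise probabilities produced by this scorer are aggregated per candidate before sorting: in permutation_pipeline (and later in the batched rerank_batch), candidate i accumulates p(i beats j) plus 1 - p(j beats i) over every other candidate j. A small self-contained sketch of that aggregation, with illustrative names:

    from typing import List

    def aggregate_pairwise(p: List[List[float]]) -> List[float]:
        # p[i][j]: model probability that candidate i is more relevant than candidate j.
        n = len(p)
        scores = [0.0] * n
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                # Symmetric evidence: i winning the (i, j) prompt plus j losing the (j, i) prompt.
                scores[i] += p[i][j] + (1.0 - p[j][i])
        return scores  # higher score => ranked earlier

For example, with two candidates and p[0][1] = 0.8, p[1][0] = 0.3, the scores come out to [1.5, 0.5], so candidate 0 is ranked first.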
- - def _add_few_shot_examples(self, conv): - return 1 - # unused for now - def create_prompt( self, result: Result, rank_start: int, rank_end: int ) -> Tuple[str, int]: @@ -152,6 +135,9 @@ def get_num_tokens(self, prompt: str) -> int: def cost_per_1k_token(self, input_token: bool) -> float: return 0 + def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: + return 1 + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: if x.score < y.score: return -1 @@ -160,6 +146,16 @@ def candidate_comparator(self, x: Candidate, y: Candidate) -> int: else: return 0 + def _add_prefix_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." + + def _add_post_prompt(self, query: str, num: int) -> str: + return f"Given the query: {query}, output its relevance to the {num} documents." + + def _add_few_shot_examples(self, conv): + return 1 + # unused for now + def permutation_pipeline( self, result: Result, diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 90242c7e..38d7da27 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -1,20 +1,35 @@ import logging from abc import ABC -from rank_llm.rerank.rankllm import RankLLM +from rank_llm.rerank.rankllm import PromptMode, RankLLM logger = logging.getLogger(__name__) class PairwiseRankLLM(RankLLM, ABC): - def __init__( + """ + Abstract base class that all pairwise rerankers implement. + + All concrete children of RankLLM must implement these functions: + - rerank_batch + - run_llm_batched + - run_llm + - create_prompt_batched + - create_prompt + - get_num_tokens + - cost_per_1k_tokens + - num_output_tokens + """ + + def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: + super.__init__(model, context_size, prompt_mode) + + # TODO + def get_output_filename( self, - model: str, - device: str = "cuda", - window_size: int = 20, - batched: bool = False, - ) -> None: - super.__init__(model) - self._window_size = window_size - self._device = device - self._batched = batched + top_k_candidates: int, + dataset_name: str, + shuffle_candidates: bool, + **kwargs: logging.Any, + ) -> str: + return diff --git a/src/rank_llm/rerank/pointwise/__init__.py b/src/rank_llm/rerank/pointwise/__init__.py index 2d622682..28403f7b 100644 --- a/src/rank_llm/rerank/pointwise/__init__.py +++ b/src/rank_llm/rerank/pointwise/__init__.py @@ -1,3 +1,3 @@ -from .pointwise_rankllm import PointwiseRankLLM +from .monot5 import MonoT5 -__all__ = ["PointwiseRankLLM"] +__all__ = ["MonoT5"] diff --git a/src/rank_llm/rerank/pointwise/pointwise_rankllm.py b/src/rank_llm/rerank/pointwise/pointwise_rankllm.py index 40178d7a..674e6857 100644 --- a/src/rank_llm/rerank/pointwise/pointwise_rankllm.py +++ b/src/rank_llm/rerank/pointwise/pointwise_rankllm.py @@ -18,8 +18,15 @@ class PointwiseRankLLM(RankLLM, ABC): """ + Abstract base class that all pointwise rerankers implement. 
+ All children of PointwiseRankLLM must implement these functions: - - currently all abstract functions of RankLLM + - run_llm_batched + - run_llm + - create_prompt + - get_num_tokens + - cost_per_1k_tokens + - num_output_tokens """ @@ -114,14 +121,6 @@ def create_prompt_batched( return prompts, token_counts - def candidate_comparator(self, x: Candidate, y: Candidate) -> int: - if x.score < y.score: - return -1 - elif x.score > y.score: - return 1 - else: - return 0 - def get_output_filename( self, top_k_candidates: int, @@ -151,6 +150,14 @@ def get_output_filename( else f"{name}_{datetime.isoformat(datetime.now())}" ) + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: + if x.score < y.score: + return -1 + elif x.score > y.score: + return 1 + else: + return 0 + def _replace_number(self, s: str) -> str: return re.sub(r"\[(\d+)\]", r"(\1)", s) diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py index 7df7b1c8..4b3e911d 100644 --- a/src/rank_llm/rerank/rankllm.py +++ b/src/rank_llm/rerank/rankllm.py @@ -21,11 +21,60 @@ def __str__(self): class RankLLM(ABC): + """ + Abstract base class that all rerankers inherit. + + All concrete children of RankLLM must implement these functions: + - rerank_batch + - run_llm_batched + - run_llm + - create_prompt_batched + - create_prompt + - get_num_tokens + - cost_per_1k_tokens + - num_output_tokens + - get_output_filename + + """ + def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: self._model = model self._context_size = context_size self._prompt_mode = prompt_mode + @abstractmethod + def rerank_batch( + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: Any, + ) -> List[Result]: + """ + Reranks a list of requests using the RankLLM agent. + + This function applies a sliding window algorithm to rerank the results. + Each window of results is processed by the RankLLM agent to obtain a new ranking. + + Args: + requests (List[Request]): The list of requests. Each request has a query and a candidates list. + rank_start (int, optional): The starting rank for processing. Defaults to 0. + rank_end (int, optional): The end rank for processing. Defaults to 100. + window_size (int, optional): The size of each sliding window. Defaults to 20. + step (int, optional): The step size for moving the window. Defaults to 10. + shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. + logging (bool, optional): Enables logging of the reranking process. Defaults to False. + vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. + populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. + batched (bool, optional): Whether to use batched processing. Defaults to False. + + Returns: + List[Result]: A list containing the reranked candidates. + """ + pass + @abstractmethod def run_llm_batched( self, prompts: List[Union[str, List[Dict[str, str]]]], **kwargs @@ -126,39 +175,6 @@ def num_output_tokens(self) -> int: """ pass - @abstractmethod - def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - shuffle_candidates: bool = False, - logging: bool = False, - **kwargs: Any, - ) -> List[Result]: - """ - Reranks a list of requests using the RankLLM agent. - - This function applies a sliding window algorithm to rerank the results. 
- Each window of results is processed by the RankLLM agent to obtain a new ranking. - - Args: - requests (List[Request]): The list of requests. Each request has a query and a candidates list. - rank_start (int, optional): The starting rank for processing. Defaults to 0. - rank_end (int, optional): The end rank for processing. Defaults to 100. - window_size (int, optional): The size of each sliding window. Defaults to 20. - step (int, optional): The step size for moving the window. Defaults to 10. - shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. - logging (bool, optional): Enables logging of the reranking process. Defaults to False. - vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. - populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. - batched (bool, optional): Whether to use batched processing. Defaults to False. - - Returns: - List[Result]: A list containing the reranked candidates. - """ - pass - @abstractmethod def get_output_filename( self, From 912de1aa29d37ca312a52ceef9c063acba612362 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Thu, 29 Aug 2024 13:58:33 -0400 Subject: [PATCH 03/14] duot5 and pairwise implementation --- src/rank_llm/rerank/pairwise/duot5.py | 224 ++++++------------ .../rerank/pairwise/pairwise_rankllm.py | 173 +++++++++++++- src/rank_llm/rerank/rankllm.py | 83 +++---- src/rank_llm/rerank/reranker.py | 27 +++ src/rank_llm/retrieve_and_rerank.py | 4 +- 5 files changed, 307 insertions(+), 204 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index e77f14a9..5c901318 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -1,17 +1,12 @@ import logging import math -from concurrent.futures import ThreadPoolExecutor, as_completed -from functools import cmp_to_key -from typing import List, Optional, Tuple +from typing import List, Tuple -import torch -from tqdm import tqdm from transformers import T5ForConditionalGeneration, T5Tokenizer from transformers.generation import GenerationConfig -from rank_llm.data import Candidate, Request, Result +from rank_llm.data import Result from rank_llm.rerank.pairwise.pairwise_rankllm import PairwiseRankLLM -from rank_llm.rerank.rankllm import PromptMode logger = logging.getLogger(__name__) @@ -20,49 +15,81 @@ class DuoT5(PairwiseRankLLM): def __init__( self, model: str, - context_size: int, - prompt_mode: PromptMode, + prompt_mode: str = "duot5", + context_size: int = 512, + device: str = "cuda", + batch_size: int = 32, ): - super.__init__(model, context_size, prompt_mode) - self._tokenizer = T5Tokenizer.from_pretrained("castorini/duot5-base-msmarco") - self._llm = T5ForConditionalGeneration.from_pretrained( - "castorini/duot5-base-msmarco" - ).to(self._device) + super().__init__( + model=model, + context_size=context_size, + prompt_mode=prompt_mode, + device=device, + batch_size=batch_size, + ) + + self._tokenizer = T5Tokenizer.from_pretrained(model) + self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device) + self._context_size = context_size - # TODO - def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - shuffle_candidates: bool = False, - logging: bool = False, - **kwargs: logging.Any, - ) -> List[Result]: - return - - # TODO def run_llm_batched( - self, prompts: List[str | List[torch.Dict[str, str]]], **kwargs - ) -> 
List[Tuple[str | int]]: - return - - def run_llm( - self, prompt: str, current_window_size: Optional[int] = None - ) -> Tuple[str, int, float]: - # CHANGE THIS CODE - if current_window_size is None: - current_window_size = self._window_size - inputs = self._tokenizer([prompt]) - inputs = {k: torch.tensor(v).to(self._device) for k, v in inputs.items()} + self, + prompts: List[str], + ) -> Tuple[List[str], List[int], List[float]]: gen_cfg = GenerationConfig.from_model_config(self._llm.config) gen_cfg.max_new_tokens = self.num_output_tokens() gen_cfg.min_new_tokens = self.num_output_tokens() - gen_cfg.decoder_start_token_id = None gen_cfg.output_scores = True gen_cfg.return_dict_in_generate = True - # gen_cfg.temperature = 0 gen_cfg.do_sample = False + + all_outputs = [] + all_output_token_counts = [] + all_scores = [] + + batch_prompts = prompts + + token_prompts = self._tokenizer( + batch_prompts, padding=True, truncation=True, return_tensors="pt" + ).to(self._device) + + token_prompts = token_prompts["input_ids"] + + batch_outputs = self._llm.generate(token_prompts, generation_config=gen_cfg) + + batch_output_ids = batch_outputs.sequences + batch_logits = batch_outputs.scores + + batch_outputs = [ + self._tokenizer.decode( + single_token_sequence, + skip_special_tokens=True, + spaces_between_special_tokens=False, + ) + for single_token_sequence in batch_output_ids + ] + + for logit_tensor in batch_logits[0]: + truth_logit = logit_tensor[1176] + false_logit = logit_tensor[6136] + score = math.exp(truth_logit) / ( + math.exp(truth_logit) + math.exp(false_logit) + ) + all_scores.append(score) + all_output_token_counts.append(self.num_output_tokens) + + all_outputs.extend(batch_outputs) + + return all_outputs, all_output_token_counts, all_scores + + def run_llm(self, prompt: str) -> Tuple[str, int, float]: + gen_cfg = GenerationConfig.from_model_config(self._llm.config) + gen_cfg.max_new_tokens = self.num_output_tokens() + gen_cfg.min_new_tokens = self.num_output_tokens() + gen_cfg.output_scores = True + gen_cfg.return_dict_in_generate = True + gen_cfg.do_sample = False + token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to( self._device ) @@ -80,125 +107,26 @@ def run_llm( truth_logit = logits[0][0][1176] false_logit = logits[0][0][6136] score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) - # print(outputs, output_ids.size(0)) + return outputs, output_ids.size(0), score - def create_prompt( - self, result: Result, rank_start: int, rank_end: int - ) -> Tuple[str, int]: - # query = result.query.text - # query = self._replace_number(query) - # input = f"Query: {query} Document: {result.candidates[rank_start].doc['contents']}" - # prompt = self._tokenizer.decode(self._tokenizer.encode(input)[:480])[:-4] + " Relevant: " - # prompt = prompt.replace("","") + def num_output_tokens(self) -> int: + return 1 - # CHANGE THIS CODE + def create_prompt(self, result: Result, index1: int, index2: int) -> Tuple[str, int]: query = result.query.text query = self._replace_number(query) - doc1 = result.candidates[rank_start].doc["contents"] - doc2 = result.candidates[rank_end].doc["contents"] + doc1 = self.convert_doc_to_prompt_content(result.candidates[index1].doc, max_length=self._context_size) + doc2 = self.convert_doc_to_prompt_content(result.candidates[index2].doc, max_length=self._context_size) doc1 = self._tokenizer.decode(self._tokenizer.encode(doc1)[:240])[:-4] doc2 = self._tokenizer.decode(self._tokenizer.encode(doc2)[:240])[:-4] prompt = f"Query: {query} 
Document0: {doc1} Document1: {doc2} Relevant:" - prompt = prompt.replace("", "") + prompt = prompt.replace("","") + return prompt, self.get_num_tokens(prompt) - def create_prompt_batched( - self, - results: List[Result], - rank_start: int, - rank_end: int, - batch_size: int = 32, - ) -> List[Tuple[str, int]]: - def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - all_completed_prompts = [] - - with ThreadPoolExecutor() as executor: - for batch in tqdm(chunks(results, batch_size), desc="Processing batches"): - futures = [ - executor.submit(self.create_prompt, result, rank_start, rank_end) - for result in batch - ] - completed_prompts = [ - future.result() for future in as_completed(futures) - ] - all_completed_prompts.extend(completed_prompts) - return all_completed_prompts - def get_num_tokens(self, prompt: str) -> int: return len(self._tokenizer.encode(prompt)) def cost_per_1k_token(self, input_token: bool) -> float: return 0 - - def num_output_tokens(self, current_window_size: Optional[int] = None) -> int: - return 1 - - def candidate_comparator(self, x: Candidate, y: Candidate) -> int: - if x.score < y.score: - return -1 - elif x.score > y.score: - return 1 - else: - return 0 - - def _add_prefix_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." - - def _add_post_prompt(self, query: str, num: int) -> str: - return f"Given the query: {query}, output its relevance to the {num} documents." - - def _add_few_shot_examples(self, conv): - return 1 - # unused for now - - def permutation_pipeline( - self, - result: Result, - rank_start: int, - rank_end: int, - logging: bool = False, - ) -> Result: - """ - Runs the permutation pipeline on the passed in result set within the passed in rank range. - - Args: - result (Result): The result object to process. - rank_start (int): The start index for ranking. - rank_end (int): The end index for ranking. - logging (bool, optional): Flag to enable logging of operations. Defaults to False. - - Returns: - Result: The processed result object after applying permutation. 
- """ - # CHANGE THIS CODE - # print(len(result.candidates)) - # for i in range (len(result.candidates)): - # prompt, num_tokens = self.create_prompt(result, i, rank_end) - # output, output_num_tokens, score = self.run_llm(prompt=prompt) - # (result.candidates[i]).score = score - - # result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) - n = len(result.candidates) - scores = [0 for _ in range(n)] - for i in range(n): - for j in range(n): - if j == i: - continue - else: - prompt1, num_tokens1 = self.create_prompt(result, i, j) - prompt2, num_tokens2 = self.create_prompt(result, j, i) - _, _, pi_j = self.run_llm(prompt=prompt1) - _, _, pj_i = self.run_llm(prompt=prompt2) - scores[i] = scores[i] + pi_j + 1 - pj_i - - for i in range(n): - (result.candidates[i]).score = scores[i] - - result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) - - return result diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 38d7da27..5f4913ec 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -1,11 +1,20 @@ +import copy import logging +import math +import re from abc import ABC +from datetime import datetime +from functools import cmp_to_key +from typing import Any, Dict, List, Tuple +from ftfy import fix_text +from tqdm import tqdm + +from rank_llm.data import Candidate, Request, Result from rank_llm.rerank.rankllm import PromptMode, RankLLM logger = logging.getLogger(__name__) - class PairwiseRankLLM(RankLLM, ABC): """ Abstract base class that all pairwise rerankers implement. @@ -21,15 +30,167 @@ class PairwiseRankLLM(RankLLM, ABC): - num_output_tokens """ - def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: - super.__init__(model, context_size, prompt_mode) + def __init__( + self, + model: str, + context_size: int, + prompt_mode: PromptMode, + device: str = "cuda", + filename: str = "", + batch_size: int = 32, + ) -> None: + super().__init__(model, context_size, prompt_mode) + self._device = device + self._filename = filename + self._batch_size = batch_size + + def rerank_batch( + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: Any, + ) -> List[Result]: + rerank_results = [ + Result( + query=copy.deepcopy(request.query), + candidates=copy.deepcopy(request.candidates), + ranking_exec_summary=[], + ) + for request in requests + ] + + for result in rerank_results: + for i in result.candidates: + i.score = 0 + + + end = len(rerank_results[0].candidates) * len(rerank_results[0].candidates) * len(requests) + with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: + for index in range(0, end, self._batch_size): + prompts, token_counts = self.create_prompt_batched( + results=rerank_results, index=index + ) + + outputs, output_tokens, scores = self.run_llm_batched(prompts=prompts) + + for update_index in range ( + index, + min( + index + self._batch_size, + end + ) + ): + query_number = math.floor( + update_index / (len(rerank_results[0].candidates) ** 2) + ) + candidate_1 = math.floor( + (update_index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) + ) + candidate_2 = update_index % len(rerank_results[0].candidates) + rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index] + rerank_results[query_number].candidates[candidate_2].score += 1 - 
scores[update_index - index] + + if index + self._batch_size > end: + progress_bar.update(end - index) + else: + progress_bar.update(self._batch_size) + + + for result in rerank_results: + result.candidates.sort( + key=cmp_to_key(self.candidate_comparator), reverse=True + ) + + return rerank_results + + def create_prompt_batched( + self, results: List[Result], index + ) -> Tuple[List[str], List[int]]: + prompts = [] + token_counts = [] + + for index in range( + index, + min(index + self._batch_size, len(results[0].candidates) * len(results)), + ): + query_number = math.floor( + index / (len(results[0].candidates) ** 2) + ) + candidate_1 = math.floor( + (index % (len(results[0].candidates) ** 2)) / len(results[0].candidates) + ) + candidate_2 = index % len(results[0].candidates) + + prompt, token_count = self.create_prompt( + result=results[query_number], index1=candidate_1, index2=candidate_2 + ) + + prompts.append(prompt) + token_counts.append(token_count) + return prompts, token_counts - # TODO def get_output_filename( self, top_k_candidates: int, dataset_name: str, shuffle_candidates: bool, - **kwargs: logging.Any, + **kwargs: Any, + ) -> str: + if self._filename != "": + return self._filename + _modelname = self._model.split("/")[-1] + if _modelname.startswith("checkpoint"): + _modelname = self._model.split("/")[-2] + "_" + _modelname + name = ( + f"{_modelname}_{self._context_size}_{top_k_candidates}_{self._prompt_mode}" + ) + if dataset_name: + name = f"{name}_{dataset_name}" + + if shuffle_candidates: + self._filename = f"{name}_shuffled_{datetime.isoformat(datetime.now())}" + else: + self._filename = f"{name}_{datetime.isoformat(datetime.now())}" + + return ( + f"{name}_shuffled_{datetime.isoformat(datetime.now())}" + if shuffle_candidates + else f"{name}_{datetime.isoformat(datetime.now())}" + ) + + def candidate_comparator(self, x: Candidate, y: Candidate) -> int: + if x.score < y.score: + return -1 + elif x.score > y.score: + return 1 + else: + return 0 + + def _replace_number(self, s: str) -> str: + return re.sub(r"\[(\d+)\]", r"(\1)", s) + + def convert_doc_to_prompt_content( + self, doc: Dict[str, Any], max_length: int ) -> str: - return + if "text" in doc: + content = doc["text"] + elif "segment" in doc: + content = doc["segment"] + elif "contents" in doc: + content = doc["contents"] + elif "content" in doc: + content = doc["content"] + elif "body" in doc: + content = doc["body"] + else: + content = doc["passage"] + if "title" in doc and doc["title"]: + content = "Title: " + doc["title"] + " " + "Content: " + content + content = content.strip() + content = fix_text(content) + # For Japanese should cut by character: content = content[:int(max_length)] + content = " ".join(content.split()[: int(max_length)]) + return self._replace_number(content) diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py index 4b3e911d..0071e7a7 100644 --- a/src/rank_llm/rerank/rankllm.py +++ b/src/rank_llm/rerank/rankllm.py @@ -14,6 +14,7 @@ class PromptMode(Enum): RANK_GPT_APEER = "rank_GPT_APEER" LRL = "LRL" MONOT5 = "monot5" + DUOT5 = "duot5" LiT5 = "LiT5" def __str__(self): @@ -21,60 +22,11 @@ def __str__(self): class RankLLM(ABC): - """ - Abstract base class that all rerankers inherit. 
- - All concrete children of RankLLM must implement these functions: - - rerank_batch - - run_llm_batched - - run_llm - - create_prompt_batched - - create_prompt - - get_num_tokens - - cost_per_1k_tokens - - num_output_tokens - - get_output_filename - - """ - def __init__(self, model: str, context_size: int, prompt_mode: PromptMode) -> None: self._model = model self._context_size = context_size self._prompt_mode = prompt_mode - @abstractmethod - def rerank_batch( - self, - requests: List[Request], - rank_start: int = 0, - rank_end: int = 100, - shuffle_candidates: bool = False, - logging: bool = False, - **kwargs: Any, - ) -> List[Result]: - """ - Reranks a list of requests using the RankLLM agent. - - This function applies a sliding window algorithm to rerank the results. - Each window of results is processed by the RankLLM agent to obtain a new ranking. - - Args: - requests (List[Request]): The list of requests. Each request has a query and a candidates list. - rank_start (int, optional): The starting rank for processing. Defaults to 0. - rank_end (int, optional): The end rank for processing. Defaults to 100. - window_size (int, optional): The size of each sliding window. Defaults to 20. - step (int, optional): The step size for moving the window. Defaults to 10. - shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. - logging (bool, optional): Enables logging of the reranking process. Defaults to False. - vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. - populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. - batched (bool, optional): Whether to use batched processing. Defaults to False. - - Returns: - List[Result]: A list containing the reranked candidates. - """ - pass - @abstractmethod def run_llm_batched( self, prompts: List[Union[str, List[Dict[str, str]]]], **kwargs @@ -175,6 +127,39 @@ def num_output_tokens(self) -> int: """ pass + @abstractmethod + def rerank_batch( + self, + requests: List[Request], + rank_start: int = 0, + rank_end: int = 100, + shuffle_candidates: bool = False, + logging: bool = False, + **kwargs: Any, + ) -> List[Result]: + """ + Reranks a list of requests using the RankLLM agent. + + This function applies a sliding window algorithm to rerank the results. + Each window of results is processed by the RankLLM agent to obtain a new ranking. + + Args: + requests (List[Request]): The list of requests. Each request has a query and a candidates list. + rank_start (int, optional): The starting rank for processing. Defaults to 0. + rank_end (int, optional): The end rank for processing. Defaults to 100. + window_size (int, optional): The size of each sliding window. Defaults to 20. + step (int, optional): The step size for moving the window. Defaults to 10. + shuffle_candidates (bool, optional): Whether to shuffle candidates before reranking. Defaults to False. + logging (bool, optional): Enables logging of the reranking process. Defaults to False. + vllm_batched (bool, optional): Whether to use VLLM batched processing. Defaults to False. + populate_exec_summary (bool, optional): Whether to populate the exec summary. Defaults to False. + batched (bool, optional): Whether to use batched processing. Defaults to False. + + Returns: + List[Result]: A list containing the reranked candidates. 
+ """ + pass + @abstractmethod def get_output_filename( self, diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 21fa3f1d..446b033a 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -11,6 +11,7 @@ from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeOpenai from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore from rank_llm.rerank.pointwise.monot5 import MonoT5 +from rank_llm.rerank.pairwise.duot5 import DuoT5 from rank_llm.rerank.rankllm import RankLLM @@ -282,7 +283,33 @@ def create_agent( device=device, batch_size=batch_size, ) + elif "duot5" in model_path: + # using monot5 + print(f"Loading {model_path} ...") + + model_full_paths = {"duot5": "castorini/duot5-3b-med-msmarco"} + + keys_and_defaults = [ + ("prompt_mode", PromptMode.DUOT5), + ("context_size", 512), + ("device", "cuda"), + ("batch_size", 64), + ] + [prompt_mode, context_size, device, batch_size] = extract_kwargs( + keys_and_defaults, **kwargs + ) + agent = DuoT5( + model=( + model_full_paths[model_path] + if model_path in model_full_paths + else model_path + ), + prompt_mode=prompt_mode, + context_size=context_size, + device=device, + batch_size=batch_size, + ) elif "lit5-distill" in model_path.lower(): keys_and_defaults = [ ("context_size", 150), diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py index 992985dd..0d574fe6 100644 --- a/src/rank_llm/retrieve_and_rerank.py +++ b/src/rank_llm/retrieve_and_rerank.py @@ -53,7 +53,9 @@ def retrieve_and_rerank( dataset=dataset, **kwargs, ) - + print(top_k_retrieve) + print("Number of candidates per query: ") + print(len(requests[0].candidates)) # Reranking stages print(f"Reranking and returning {top_k_rerank} passages with {model_path}...") if reranker.get_agent() is None: From b1d7eff3df5fcc6ca0a1f6d68c339142cc7dca4e Mon Sep 17 00:00:00 2001 From: IR3KT4FUNZ Date: Mon, 2 Sep 2024 16:45:51 -0400 Subject: [PATCH 04/14] implementation for pairwise and duot5 --- src/rank_llm/rerank/pairwise/pairwise_rankllm.py | 10 ++++++++-- src/rank_llm/rerank/reranker.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 5f4913ec..f5202d05 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -53,6 +53,11 @@ def rerank_batch( logging: bool = False, **kwargs: Any, ) -> List[Result]: + self._enumerated_indices = [] + + for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)): + self._enumerated_indices.append(index) + rerank_results = [ Result( query=copy.deepcopy(request.query), @@ -66,8 +71,7 @@ def rerank_batch( for i in result.candidates: i.score = 0 - - end = len(rerank_results[0].candidates) * len(rerank_results[0].candidates) * len(requests) + end = len(rerank_results[0].candidates - 1) * len(rerank_results[0].candidates) * len(requests) with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: for index in range(0, end, self._batch_size): prompts, token_counts = self.create_prompt_batched( @@ -83,6 +87,7 @@ def rerank_batch( end ) ): + update_index = self._enumerated_indices[update_index] query_number = math.floor( update_index / (len(rerank_results[0].candidates) ** 2) ) @@ -116,6 +121,7 @@ def create_prompt_batched( index, min(index + self._batch_size, len(results[0].candidates) * len(results)), ): + index = 
self._enumerated_indices[index] query_number = math.floor( index / (len(results[0].candidates) ** 2) ) diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 446b033a..a7863660 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -294,6 +294,7 @@ def create_agent( ("context_size", 512), ("device", "cuda"), ("batch_size", 64), + ("interactive", True) ] [prompt_mode, context_size, device, batch_size] = extract_kwargs( keys_and_defaults, **kwargs From b35f2521767005276f8e15dcefa462178cd8fe0a Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Mon, 2 Sep 2024 18:43:14 -0400 Subject: [PATCH 05/14] duot5 bug fixes, add fix for retrieving <100 candidates in non-interactive cases --- .../rerank/pairwise/pairwise_rankllm.py | 23 ++++++++++--------- src/rank_llm/rerank/reranker.py | 1 - src/rank_llm/retrieve_and_rerank.py | 7 +++--- src/rank_llm/scripts/run_rank_llm.py | 8 +++++++ 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index f5202d05..a6e1d8ba 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -71,7 +71,7 @@ def rerank_batch( for i in result.candidates: i.score = 0 - end = len(rerank_results[0].candidates - 1) * len(rerank_results[0].candidates) * len(requests) + end = (len(rerank_results[0].candidates) - 1) * len(rerank_results[0].candidates) * len(requests) with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: for index in range(0, end, self._batch_size): prompts, token_counts = self.create_prompt_batched( @@ -87,14 +87,15 @@ def rerank_batch( end ) ): - update_index = self._enumerated_indices[update_index] + update_index_copy = self._enumerated_indices[update_index] query_number = math.floor( - update_index / (len(rerank_results[0].candidates) ** 2) + update_index_copy / (len(rerank_results[0].candidates) ** 2) ) candidate_1 = math.floor( - (update_index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) + (update_index_copy % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) ) - candidate_2 = update_index % len(rerank_results[0].candidates) + candidate_2 = update_index_copy % len(rerank_results[0].candidates) + rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index] rerank_results[query_number].candidates[candidate_2].score += 1 - scores[update_index - index] @@ -117,18 +118,18 @@ def create_prompt_batched( prompts = [] token_counts = [] - for index in range( + for current_index in range( index, - min(index + self._batch_size, len(results[0].candidates) * len(results)), + min(index + self._batch_size, len(results[0].candidates) * (len(results[0].candidates) - 1) * len(results)), ): - index = self._enumerated_indices[index] + current_index = self._enumerated_indices[current_index] query_number = math.floor( - index / (len(results[0].candidates) ** 2) + current_index / (len(results[0].candidates) ** 2) ) candidate_1 = math.floor( - (index % (len(results[0].candidates) ** 2)) / len(results[0].candidates) + (current_index % (len(results[0].candidates) ** 2)) / len(results[0].candidates) ) - candidate_2 = index % len(results[0].candidates) + candidate_2 = current_index % len(results[0].candidates) prompt, token_count = self.create_prompt( result=results[query_number], index1=candidate_1, index2=candidate_2 diff --git 
a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index a7863660..446b033a 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -294,7 +294,6 @@ def create_agent( ("context_size", 512), ("device", "cuda"), ("batch_size", 64), - ("interactive", True) ] [prompt_mode, context_size, device, batch_size] = extract_kwargs( keys_and_defaults, **kwargs diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py index 0d574fe6..cb904953 100644 --- a/src/rank_llm/retrieve_and_rerank.py +++ b/src/rank_llm/retrieve_and_rerank.py @@ -53,9 +53,10 @@ def retrieve_and_rerank( dataset=dataset, **kwargs, ) - print(top_k_retrieve) - print("Number of candidates per query: ") - print(len(requests[0].candidates)) + + for request in requests: + request.candidates = request.candidates[:top_k_retrieve] + # Reranking stages print(f"Reranking and returning {top_k_rerank} passages with {model_path}...") if reranker.get_agent() is None: diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 3639525b..1ce04192 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -38,6 +38,7 @@ def main(args): window_size = args.window_size system_message = args.system_message vllm_batched = args.vllm_batched + interactive = args.interactive _ = retrieve_and_rerank( model_path=model_path, @@ -62,6 +63,7 @@ def main(args): step_size=step_size, system_message=system_message, vllm_batched=vllm_batched, + interactive=interactive, ) @@ -175,5 +177,11 @@ def main(args): action="store_true", help="whether to run the model in batches", ) + parser.add_argument( + "--interactive", + type=bool, + default=False, + help="whether retrieval is done from the server or a prebuilt index" + ) args = parser.parse_args() main(args) From 0950167458e9220397c34eda2f330b0a75b464b1 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Mon, 2 Sep 2024 18:50:26 -0400 Subject: [PATCH 06/14] remove temporarily unnecessary interactive argument --- src/rank_llm/scripts/run_rank_llm.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/rank_llm/scripts/run_rank_llm.py b/src/rank_llm/scripts/run_rank_llm.py index 1ce04192..3639525b 100644 --- a/src/rank_llm/scripts/run_rank_llm.py +++ b/src/rank_llm/scripts/run_rank_llm.py @@ -38,7 +38,6 @@ def main(args): window_size = args.window_size system_message = args.system_message vllm_batched = args.vllm_batched - interactive = args.interactive _ = retrieve_and_rerank( model_path=model_path, @@ -63,7 +62,6 @@ def main(args): step_size=step_size, system_message=system_message, vllm_batched=vllm_batched, - interactive=interactive, ) @@ -177,11 +175,5 @@ def main(args): action="store_true", help="whether to run the model in batches", ) - parser.add_argument( - "--interactive", - type=bool, - default=False, - help="whether retrieval is done from the server or a prebuilt index" - ) args = parser.parse_args() main(args) From 5604a610a5ce301bf948af5dd0356eaca94518b8 Mon Sep 17 00:00:00 2001 From: Eric Wang Date: Sat, 7 Sep 2024 19:40:25 -0400 Subject: [PATCH 07/14] fix enumeration bug --- src/rank_llm/rerank/pairwise/pairwise_rankllm.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index a6e1d8ba..8f47f8ba 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -55,9 +55,6 @@ 
def rerank_batch( ) -> List[Result]: self._enumerated_indices = [] - for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)): - self._enumerated_indices.append(index) - rerank_results = [ Result( query=copy.deepcopy(request.query), @@ -71,6 +68,14 @@ def rerank_batch( for i in result.candidates: i.score = 0 + for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)): + candidate_1 = math.floor( + (index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) + ) + candidate_2 = index % len(rerank_results[0].candidates) + if candidate_1 != candidate_2: + self._enumerated_indices.append(index) + end = (len(rerank_results[0].candidates) - 1) * len(rerank_results[0].candidates) * len(requests) with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: for index in range(0, end, self._batch_size): @@ -95,10 +100,10 @@ def rerank_batch( (update_index_copy % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) ) candidate_2 = update_index_copy % len(rerank_results[0].candidates) - + rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index] rerank_results[query_number].candidates[candidate_2].score += 1 - scores[update_index - index] - + if index + self._batch_size > end: progress_bar.update(end - index) else: From 0f7eb64f9b9bdb7be42ae928d837f196527968c1 Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Tue, 11 Feb 2025 23:23:23 -0500 Subject: [PATCH 08/14] fix --- src/rank_llm/rerank/pairwise/pairwise_rankllm.py | 14 ++------------ src/rank_llm/rerank/reranker.py | 2 +- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 8f47f8ba..71da9865 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -18,16 +18,6 @@ class PairwiseRankLLM(RankLLM, ABC): """ Abstract base class that all pairwise rerankers implement. 
- - All concrete children of RankLLM must implement these functions: - - rerank_batch - - run_llm_batched - - run_llm - - create_prompt_batched - - create_prompt - - get_num_tokens - - cost_per_1k_tokens - - num_output_tokens """ def __init__( @@ -59,7 +49,7 @@ def rerank_batch( Result( query=copy.deepcopy(request.query), candidates=copy.deepcopy(request.candidates), - ranking_exec_summary=[], + invocations_history=[], ) for request in requests ] @@ -74,7 +64,7 @@ def rerank_batch( ) candidate_2 = index % len(rerank_results[0].candidates) if candidate_1 != candidate_2: - self._enumerated_indices.append(index) + self._enumerated_indices.append(index) end = (len(rerank_results[0].candidates) - 1) * len(rerank_results[0].candidates) * len(requests) with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 7c8d5508..19c5c883 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -352,7 +352,7 @@ def create_model_coordinator( keys_and_defaults, **kwargs ) - agent = DuoT5( + model_coordinator = DuoT5( model=( model_full_paths[model_path] if model_path in model_full_paths From 545d37c1cc7e2bb116d93fd532af860cbc6a5490 Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Thu, 13 Feb 2025 01:29:30 -0500 Subject: [PATCH 09/14] fix pairwise --- src/rank_llm/rerank/pairwise/duot5.py | 120 ++++++++-------- .../rerank/pairwise/pairwise_rankllm.py | 136 ++++++++---------- 2 files changed, 117 insertions(+), 139 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index 5c901318..97886e23 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -27,10 +27,16 @@ def __init__( device=device, batch_size=batch_size, ) - + self._tokenizer = T5Tokenizer.from_pretrained(model) self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device) self._context_size = context_size + + self._true_id = self._tokenizer.encode("true", add_special_tokens=False)[0] + self._false_id = self._tokenizer.encode("false", add_special_tokens=False)[0] + + def num_output_tokens(self) -> int: + return 1 def run_llm_batched( self, @@ -43,86 +49,74 @@ def run_llm_batched( gen_cfg.return_dict_in_generate = True gen_cfg.do_sample = False - all_outputs = [] - all_output_token_counts = [] - all_scores = [] - - batch_prompts = prompts - - token_prompts = self._tokenizer( - batch_prompts, padding=True, truncation=True, return_tensors="pt" + tokenized = self._tokenizer( + prompts, + padding=True, + truncation=True, + max_length=self._context_size, + return_tensors="pt" ).to(self._device) + input_ids = tokenized["input_ids"] - token_prompts = token_prompts["input_ids"] - - batch_outputs = self._llm.generate(token_prompts, generation_config=gen_cfg) - - batch_output_ids = batch_outputs.sequences - batch_logits = batch_outputs.scores + outputs = self._llm.generate(input_ids, generation_config=gen_cfg) + output_ids = outputs.sequences # (batch_size, sequence_length) + logits = outputs.scores # Tuple with one tensor (batch_size, vocab_size) since num_output_tokens == 1 + # Decode outputs batch_outputs = [ self._tokenizer.decode( - single_token_sequence, + seq, skip_special_tokens=True, spaces_between_special_tokens=False, ) - for single_token_sequence in batch_output_ids + for seq in output_ids ] - for logit_tensor in batch_logits[0]: - truth_logit = logit_tensor[1176] - false_logit = logit_tensor[6136] - score = 
math.exp(truth_logit) / ( - math.exp(truth_logit) + math.exp(false_logit) - ) + all_scores = [] + all_output_token_counts = [] + # Use the logits from the generated token (logits[0] is of shape (batch_size, vocab_size)) + for logit_tensor in logits[0]: + truth_logit = logit_tensor[self._true_id].item() + false_logit = logit_tensor[self._false_id].item() + score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) all_scores.append(score) - all_output_token_counts.append(self.num_output_tokens) + all_output_token_counts.append(self.num_output_tokens()) - all_outputs.extend(batch_outputs) - - return all_outputs, all_output_token_counts, all_scores + return batch_outputs, all_output_token_counts, all_scores def run_llm(self, prompt: str) -> Tuple[str, int, float]: - gen_cfg = GenerationConfig.from_model_config(self._llm.config) - gen_cfg.max_new_tokens = self.num_output_tokens() - gen_cfg.min_new_tokens = self.num_output_tokens() - gen_cfg.output_scores = True - gen_cfg.return_dict_in_generate = True - gen_cfg.do_sample = False - - token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to( - self._device - ) - output = self._llm.generate(token_prompt, generation_config=gen_cfg) - output_ids = output.sequences - logits = output.scores + ret = self.run_llm_batched([prompt]) + return (ret[0][0], ret[1][0], ret[2][0]) - if self._llm.config.is_encoder_decoder: - output_ids = output_ids[0] - output_ids = output_ids[1:] + def create_prompt(self, result: Result, index1: int, index2: int) -> Tuple[str, int]: + query = self._replace_number(result.query.text) - outputs = self._tokenizer.decode( - output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + doc1_raw = self.convert_doc_to_prompt_content( + result.candidates[index1].doc, + max_length=self._context_size + ) + doc2_raw = self.convert_doc_to_prompt_content( + result.candidates[index2].doc, + max_length=self._context_size ) - truth_logit = logits[0][0][1176] - false_logit = logits[0][0][6136] - score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) - - return outputs, output_ids.size(0), score - - def num_output_tokens(self) -> int: - return 1 - - def create_prompt(self, result: Result, index1: int, index2: int) -> Tuple[str, int]: - query = result.query.text - query = self._replace_number(query) - doc1 = self.convert_doc_to_prompt_content(result.candidates[index1].doc, max_length=self._context_size) - doc2 = self.convert_doc_to_prompt_content(result.candidates[index2].doc, max_length=self._context_size) - doc1 = self._tokenizer.decode(self._tokenizer.encode(doc1)[:240])[:-4] - doc2 = self._tokenizer.decode(self._tokenizer.encode(doc2)[:240])[:-4] - prompt = f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant:" - prompt = prompt.replace("","") + doc1_tokens = self._tokenizer.encode( + doc1_raw, + truncation=True, + max_length=self._context_size + ) + doc2_tokens = self._tokenizer.encode( + doc2_raw, + truncation=True, + max_length=self._context_size + ) + + doc1 = self._tokenizer.decode(doc1_tokens, skip_special_tokens=True) + doc2 = self._tokenizer.decode(doc2_tokens, skip_special_tokens=True) + + prompt = f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant: " + prompt = prompt.replace("", "") + return prompt, self.get_num_tokens(prompt) def get_num_tokens(self, prompt: str) -> int: diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 71da9865..2bc3689c 100644 --- 
a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -43,7 +43,12 @@ def rerank_batch( logging: bool = False, **kwargs: Any, ) -> List[Result]: - self._enumerated_indices = [] + """ + Re-rank candidates in a pairwise fashion: + 1. Build a list of all pairwise comparisons. + 2. Process in batches: create prompts, run LLM, update scores. + 3. Sort candidates by final score in descending order. + """ rerank_results = [ Result( @@ -54,84 +59,66 @@ def rerank_batch( for request in requests ] + # zero-initialize candidate scores for result in rerank_results: - for i in result.candidates: - i.score = 0 - - for index in range(len(requests) * len(requests[0].candidates) * len(requests[0].candidates)): - candidate_1 = math.floor( - (index % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) - ) - candidate_2 = index % len(rerank_results[0].candidates) - if candidate_1 != candidate_2: - self._enumerated_indices.append(index) - - end = (len(rerank_results[0].candidates) - 1) * len(rerank_results[0].candidates) * len(requests) - with tqdm(total=end, desc="Progress through (q, d) pairs") as progress_bar: - for index in range(0, end, self._batch_size): - prompts, token_counts = self.create_prompt_batched( - results=rerank_results, index=index - ) - - outputs, output_tokens, scores = self.run_llm_batched(prompts=prompts) - - for update_index in range ( - index, - min( - index + self._batch_size, - end - ) - ): - update_index_copy = self._enumerated_indices[update_index] - query_number = math.floor( - update_index_copy / (len(rerank_results[0].candidates) ** 2) - ) - candidate_1 = math.floor( - (update_index_copy % (len(rerank_results[0].candidates) ** 2)) / len(rerank_results[0].candidates) - ) - candidate_2 = update_index_copy % len(rerank_results[0].candidates) - - rerank_results[query_number].candidates[candidate_1].score += scores[update_index - index] - rerank_results[query_number].candidates[candidate_2].score += 1 - scores[update_index - index] - - if index + self._batch_size > end: - progress_bar.update(end - index) - else: - progress_bar.update(self._batch_size) - - + for candidate in result.candidates: + candidate.score = 0 + + num_queries, num_pairs = len(rerank_results), 0 + self._enumerated_indices = [[] for _ in range(num_queries)] + for query_idx, res in enumerate(rerank_results): + num_candidates = len(res.candidates) + for i in range(num_candidates): + for j in range(i+1,num_candidates): + self._enumerated_indices[query_idx].append([i,j]) + num_pairs += len(self._enumerated_indices[query_idx]) + + with tqdm(total=num_pairs, desc="Progress through (q, d) pairs") as progress_bar: + for query_idx, pair_list in enumerate(self._enumerated_indices): + index = 0 + while index < len(pair_list): + prompts, token_counts = self.create_prompt_batched(rerank_results, query_idx, index) + + outputs, output_tokens, scores = self.run_llm_batched(prompts) + + for (i, j), score in zip(pair_list[index : index + len(scores)], scores): + rerank_results[query_idx].candidates[i].score += score + rerank_results[query_idx].candidates[j].score += (1 - score) + + index += self._batch_size + progress_bar.update(len(scores)) + for result in rerank_results: - result.candidates.sort( - key=cmp_to_key(self.candidate_comparator), reverse=True - ) + result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) return rerank_results def create_prompt_batched( - self, results: List[Result], index + self, + results: List[Result], + 
query_idx: int, + index: int ) -> Tuple[List[str], List[int]]: - prompts = [] - token_counts = [] - - for current_index in range( - index, - min(index + self._batch_size, len(results[0].candidates) * (len(results[0].candidates) - 1) * len(results)), - ): - current_index = self._enumerated_indices[current_index] - query_number = math.floor( - current_index / (len(results[0].candidates) ** 2) + """ + Create a batch of prompts for the given query_idx, taking pairs of candidates + from self._enumerated_indices[query_idx] in the range [index : index + batch_size]. + """ + prompts, token_counts = [], [] + + pair_list = self._enumerated_indices[query_idx] + end_index = min(index + self._batch_size, len(pair_list)) + + # Build prompts for each pair in [index, end_index) + for pair_idx in range(index, end_index): + i, j = pair_list[pair_idx] + prompt, tcount = self.create_prompt( + result=results[query_idx], + index1=i, + index2=j ) - candidate_1 = math.floor( - (current_index % (len(results[0].candidates) ** 2)) / len(results[0].candidates) - ) - candidate_2 = current_index % len(results[0].candidates) - - prompt, token_count = self.create_prompt( - result=results[query_number], index1=candidate_1, index2=candidate_2 - ) - prompts.append(prompt) - token_counts.append(token_count) + token_counts.append(tcount) + return prompts, token_counts def get_output_filename( @@ -164,12 +151,9 @@ def get_output_filename( ) def candidate_comparator(self, x: Candidate, y: Candidate) -> int: - if x.score < y.score: - return -1 - elif x.score > y.score: - return 1 - else: - return 0 + if x.score < y.score: return 1 + elif x.score > y.score: return -1 + else: return 0 def _replace_number(self, s: str) -> str: return re.sub(r"\[(\d+)\]", r"(\1)", s) From 93f4b944541a25933d8833979a5a808b409a3485 Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Thu, 13 Feb 2025 01:30:20 -0500 Subject: [PATCH 10/14] lint --- src/rank_llm/rerank/pairwise/duot5.py | 38 +++++++-------- .../rerank/pairwise/pairwise_rankllm.py | 48 ++++++++++--------- src/rank_llm/rerank/reranker.py | 2 +- src/rank_llm/retrieve_and_rerank.py | 2 +- 4 files changed, 47 insertions(+), 43 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index 97886e23..be0753dc 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -27,11 +27,11 @@ def __init__( device=device, batch_size=batch_size, ) - + self._tokenizer = T5Tokenizer.from_pretrained(model) self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device) self._context_size = context_size - + self._true_id = self._tokenizer.encode("true", add_special_tokens=False)[0] self._false_id = self._tokenizer.encode("false", add_special_tokens=False)[0] @@ -54,13 +54,15 @@ def run_llm_batched( padding=True, truncation=True, max_length=self._context_size, - return_tensors="pt" + return_tensors="pt", ).to(self._device) input_ids = tokenized["input_ids"] outputs = self._llm.generate(input_ids, generation_config=gen_cfg) output_ids = outputs.sequences # (batch_size, sequence_length) - logits = outputs.scores # Tuple with one tensor (batch_size, vocab_size) since num_output_tokens == 1 + logits = ( + outputs.scores + ) # Tuple with one tensor (batch_size, vocab_size) since num_output_tokens == 1 # Decode outputs batch_outputs = [ @@ -78,7 +80,9 @@ def run_llm_batched( for logit_tensor in logits[0]: truth_logit = logit_tensor[self._true_id].item() false_logit = logit_tensor[self._false_id].item() - score = 
math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) + score = math.exp(truth_logit) / ( + math.exp(truth_logit) + math.exp(false_logit) + ) all_scores.append(score) all_output_token_counts.append(self.num_output_tokens()) @@ -88,32 +92,28 @@ def run_llm(self, prompt: str) -> Tuple[str, int, float]: ret = self.run_llm_batched([prompt]) return (ret[0][0], ret[1][0], ret[2][0]) - def create_prompt(self, result: Result, index1: int, index2: int) -> Tuple[str, int]: + def create_prompt( + self, result: Result, index1: int, index2: int + ) -> Tuple[str, int]: query = self._replace_number(result.query.text) doc1_raw = self.convert_doc_to_prompt_content( - result.candidates[index1].doc, - max_length=self._context_size + result.candidates[index1].doc, max_length=self._context_size ) doc2_raw = self.convert_doc_to_prompt_content( - result.candidates[index2].doc, - max_length=self._context_size + result.candidates[index2].doc, max_length=self._context_size ) - + doc1_tokens = self._tokenizer.encode( - doc1_raw, - truncation=True, - max_length=self._context_size + doc1_raw, truncation=True, max_length=self._context_size ) doc2_tokens = self._tokenizer.encode( - doc2_raw, - truncation=True, - max_length=self._context_size + doc2_raw, truncation=True, max_length=self._context_size ) - + doc1 = self._tokenizer.decode(doc1_tokens, skip_special_tokens=True) doc2 = self._tokenizer.decode(doc2_tokens, skip_special_tokens=True) - + prompt = f"Query: {query} Document0: {doc1} Document1: {doc2} Relevant: " prompt = prompt.replace("", "") diff --git a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py index 2bc3689c..8f5f8e51 100644 --- a/src/rank_llm/rerank/pairwise/pairwise_rankllm.py +++ b/src/rank_llm/rerank/pairwise/pairwise_rankllm.py @@ -1,6 +1,5 @@ import copy import logging -import math import re from abc import ABC from datetime import datetime @@ -15,6 +14,7 @@ logger = logging.getLogger(__name__) + class PairwiseRankLLM(RankLLM, ABC): """ Abstract base class that all pairwise rerankers implement. 
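A note on the duot5.py hunk earlier in this lint commit: it only re-wraps the scoring expression, which is a two-way softmax over the logits of the "true" and "false" tokens of the single generated position. Below is a minimal standalone sketch of that calculation for illustration only; it is not part of the patch, and the helper name and logit values are invented.

# Illustrative sketch (not part of the patch): two-way softmax over "true"/"false" logits.
import math

def true_false_softmax(truth_logit: float, false_logit: float) -> float:
    # Probability mass on "true" when only the two verdict tokens are compared.
    return math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit))

score = true_false_softmax(2.3, -1.1)  # hypothetical logits -> ~0.968, Document0 judged more relevant
print(score, 1 - score)                # the pair credits `score` to Document0 and `1 - score` to Document1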
@@ -69,56 +69,57 @@ def rerank_batch( for query_idx, res in enumerate(rerank_results): num_candidates = len(res.candidates) for i in range(num_candidates): - for j in range(i+1,num_candidates): - self._enumerated_indices[query_idx].append([i,j]) + for j in range(i + 1, num_candidates): + self._enumerated_indices[query_idx].append([i, j]) num_pairs += len(self._enumerated_indices[query_idx]) - - with tqdm(total=num_pairs, desc="Progress through (q, d) pairs") as progress_bar: + + with tqdm( + total=num_pairs, desc="Progress through (q, d) pairs" + ) as progress_bar: for query_idx, pair_list in enumerate(self._enumerated_indices): index = 0 while index < len(pair_list): - prompts, token_counts = self.create_prompt_batched(rerank_results, query_idx, index) + prompts, token_counts = self.create_prompt_batched( + rerank_results, query_idx, index + ) outputs, output_tokens, scores = self.run_llm_batched(prompts) - for (i, j), score in zip(pair_list[index : index + len(scores)], scores): + for (i, j), score in zip( + pair_list[index : index + len(scores)], scores + ): rerank_results[query_idx].candidates[i].score += score - rerank_results[query_idx].candidates[j].score += (1 - score) + rerank_results[query_idx].candidates[j].score += 1 - score index += self._batch_size progress_bar.update(len(scores)) - + for result in rerank_results: result.candidates.sort(key=cmp_to_key(self.candidate_comparator)) return rerank_results def create_prompt_batched( - self, - results: List[Result], - query_idx: int, - index: int + self, results: List[Result], query_idx: int, index: int ) -> Tuple[List[str], List[int]]: """ Create a batch of prompts for the given query_idx, taking pairs of candidates from self._enumerated_indices[query_idx] in the range [index : index + batch_size]. 
""" prompts, token_counts = [], [] - + pair_list = self._enumerated_indices[query_idx] end_index = min(index + self._batch_size, len(pair_list)) - + # Build prompts for each pair in [index, end_index) for pair_idx in range(index, end_index): i, j = pair_list[pair_idx] prompt, tcount = self.create_prompt( - result=results[query_idx], - index1=i, - index2=j + result=results[query_idx], index1=i, index2=j ) prompts.append(prompt) token_counts.append(tcount) - + return prompts, token_counts def get_output_filename( @@ -151,9 +152,12 @@ def get_output_filename( ) def candidate_comparator(self, x: Candidate, y: Candidate) -> int: - if x.score < y.score: return 1 - elif x.score > y.score: return -1 - else: return 0 + if x.score < y.score: + return 1 + elif x.score > y.score: + return -1 + else: + return 0 def _replace_number(self, s: str) -> str: return re.sub(r"\[(\d+)\]", r"(\1)", s) diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index 19c5c883..cde2e5cb 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -11,8 +11,8 @@ ) from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeGenai, SafeOpenai from rank_llm.rerank.listwise.rank_fid import RankFiDDistill, RankFiDScore -from rank_llm.rerank.pointwise.monot5 import MonoT5 from rank_llm.rerank.pairwise.duot5 import DuoT5 +from rank_llm.rerank.pointwise.monot5 import MonoT5 from rank_llm.rerank.rankllm import RankLLM diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py index 3a3d5393..fc81c577 100644 --- a/src/rank_llm/retrieve_and_rerank.py +++ b/src/rank_llm/retrieve_and_rerank.py @@ -54,7 +54,7 @@ def retrieve_and_rerank( dataset=dataset, **kwargs, ) - + for request in requests: request.candidates = request.candidates[:top_k_retrieve] From 1234b3bccc1546369ec2d8a597ebe01ac095e00d Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Thu, 13 Feb 2025 01:39:04 -0500 Subject: [PATCH 11/14] cleanup monot5 --- src/rank_llm/rerank/pairwise/duot5.py | 18 ++++----- src/rank_llm/rerank/pointwise/monot5.py | 50 ++++++------------------- 2 files changed, 19 insertions(+), 49 deletions(-) diff --git a/src/rank_llm/rerank/pairwise/duot5.py b/src/rank_llm/rerank/pairwise/duot5.py index be0753dc..969d79f9 100644 --- a/src/rank_llm/rerank/pairwise/duot5.py +++ b/src/rank_llm/rerank/pairwise/duot5.py @@ -35,9 +35,6 @@ def __init__( self._true_id = self._tokenizer.encode("true", add_special_tokens=False)[0] self._false_id = self._tokenizer.encode("false", add_special_tokens=False)[0] - def num_output_tokens(self) -> int: - return 1 - def run_llm_batched( self, prompts: List[str], @@ -59,12 +56,9 @@ def run_llm_batched( input_ids = tokenized["input_ids"] outputs = self._llm.generate(input_ids, generation_config=gen_cfg) - output_ids = outputs.sequences # (batch_size, sequence_length) - logits = ( - outputs.scores - ) # Tuple with one tensor (batch_size, vocab_size) since num_output_tokens == 1 + output_ids = outputs.sequences + logits = outputs.scores - # Decode outputs batch_outputs = [ self._tokenizer.decode( seq, @@ -74,8 +68,7 @@ def run_llm_batched( for seq in output_ids ] - all_scores = [] - all_output_token_counts = [] + all_scores, all_output_token_counts = [], [] # Use the logits from the generated token (logits[0] is of shape (batch_size, vocab_size)) for logit_tensor in logits[0]: truth_logit = logit_tensor[self._true_id].item() @@ -90,7 +83,7 @@ def run_llm_batched( def run_llm(self, prompt: str) -> Tuple[str, int, float]: ret = 
self.run_llm_batched([prompt]) - return (ret[0][0], ret[1][0], ret[2][0]) + return ret[0][0], ret[1][0], ret[2][0] def create_prompt( self, result: Result, index1: int, index2: int @@ -122,5 +115,8 @@ def create_prompt( def get_num_tokens(self, prompt: str) -> int: return len(self._tokenizer.encode(prompt)) + def num_output_tokens(self) -> int: + return 1 + def cost_per_1k_token(self, input_token: bool) -> float: return 0 diff --git a/src/rank_llm/rerank/pointwise/monot5.py b/src/rank_llm/rerank/pointwise/monot5.py index 304908b7..46566fbd 100644 --- a/src/rank_llm/rerank/pointwise/monot5.py +++ b/src/rank_llm/rerank/pointwise/monot5.py @@ -32,6 +32,9 @@ def __init__( self._llm = T5ForConditionalGeneration.from_pretrained(model).to(self._device) self._context_size = context_size + self._true_id = self._tokenizer.encode("true", add_special_tokens=False)[0] + self._false_id = self._tokenizer.encode("false", add_special_tokens=False)[0] + def run_llm_batched( self, prompts: List[str], @@ -42,15 +45,10 @@ def run_llm_batched( gen_cfg.output_scores = True gen_cfg.return_dict_in_generate = True gen_cfg.do_sample = False - - all_outputs = [] - all_output_token_counts = [] - all_scores = [] - - batch_prompts = prompts + all_outputs, all_output_token_counts, all_scores = [], [], [] token_prompts = self._tokenizer( - batch_prompts, padding=True, truncation=True, return_tensors="pt" + prompts, padding=True, truncation=True, return_tensors="pt" ).to(self._device) token_prompts = token_prompts["input_ids"] @@ -70,8 +68,8 @@ def run_llm_batched( ] for logit_tensor in batch_logits[0]: - truth_logit = logit_tensor[1176] - false_logit = logit_tensor[6136] + truth_logit = logit_tensor[self._true_id] + false_logit = logit_tensor[self._false_id] score = math.exp(truth_logit) / ( math.exp(truth_logit) + math.exp(false_logit) ) @@ -83,35 +81,8 @@ def run_llm_batched( return all_outputs, all_output_token_counts, all_scores def run_llm(self, prompt: str) -> Tuple[str, int, float]: - gen_cfg = GenerationConfig.from_model_config(self._llm.config) - gen_cfg.max_new_tokens = self.num_output_tokens() - gen_cfg.min_new_tokens = self.num_output_tokens() - gen_cfg.output_scores = True - gen_cfg.return_dict_in_generate = True - gen_cfg.do_sample = False - - token_prompt = self._tokenizer.encode(prompt, return_tensors="pt").to( - self._device - ) - output = self._llm.generate(token_prompt, generation_config=gen_cfg) - output_ids = output.sequences - logits = output.scores - - if self._llm.config.is_encoder_decoder: - output_ids = output_ids[0] - output_ids = output_ids[1:] - - outputs = self._tokenizer.decode( - output_ids, skip_special_tokens=True, spaces_between_special_tokens=False - ) - truth_logit = logits[0][0][1176] - false_logit = logits[0][0][6136] - score = math.exp(truth_logit) / (math.exp(truth_logit) + math.exp(false_logit)) - - return outputs, output_ids.size(0), score - - def num_output_tokens(self) -> int: - return 1 + ret = self.run_llm_batched([prompt]) + return ret[0][0], ret[1][0], ret[2][0] def create_prompt(self, result: Result, index: int) -> Tuple[str, int]: query = result.query.text @@ -129,5 +100,8 @@ def create_prompt(self, result: Result, index: int) -> Tuple[str, int]: def get_num_tokens(self, prompt: str) -> int: return len(self._tokenizer.encode(prompt)) + def num_output_tokens(self) -> int: + return 1 + def cost_per_1k_token(self, input_token: bool) -> float: return 0 From 6e62cbb6f5e9d2a5536f04e910f26566118bca83 Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Fri, 14 Feb 2025 05:31:36 
+0000 Subject: [PATCH 12/14] change duo model path --- src/rank_llm/rerank/reranker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index cde2e5cb..9bc31e66 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -340,7 +340,7 @@ def create_model_coordinator( # using monot5 print(f"Loading {model_path} ...") - model_full_paths = {"duot5": "castorini/duot5-3b-med-msmarco"} + model_full_paths = {"duot5": "castorini/duot5-3b-msmarco-10k"} keys_and_defaults = [ ("prompt_mode", PromptMode.DUOT5), From 7f326ffa9a29cfc5128a4307260434ad1489136c Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Fri, 14 Feb 2025 05:41:14 +0000 Subject: [PATCH 13/14] add demo --- src/rank_llm/demo/rerank_duot5.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 src/rank_llm/demo/rerank_duot5.py diff --git a/src/rank_llm/demo/rerank_duot5.py b/src/rank_llm/demo/rerank_duot5.py new file mode 100644 index 00000000..b48d6ebe --- /dev/null +++ b/src/rank_llm/demo/rerank_duot5.py @@ -0,0 +1,30 @@ +import os +import sys +from pathlib import Path + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +parent = os.path.dirname(SCRIPT_DIR) +parent = os.path.dirname(parent) +sys.path.append(parent) + +from rank_llm.data import DataWriter +from rank_llm.rerank import Reranker +from rank_llm.rerank.pairwise.duot5 import DuoT5 +from rank_llm.retrieve.retriever import Retriever + +dataset = "dl20" +requests = Retriever.from_dataset_with_prebuilt_index(dataset, k=50) +duot5_model_coordinator = DuoT5("castorini/duot5-3b-msmarco-10k") +m_reranker = Reranker(duot5_model_coordinator) +kwargs = {"populate_invocations_history": True} +rerank_results = m_reranker.rerank_batch(requests, **kwargs) +print(rerank_results) + +# write rerank results +writer = DataWriter(rerank_results) +Path(f"demo_outputs/").mkdir(parents=True, exist_ok=True) +writer.write_in_jsonl_format(f"demo_outputs/rerank_results.jsonl") +writer.write_in_trec_eval_format(f"demo_outputs/rerank_results.txt") +writer.write_inference_invocations_history( + f"demo_outputs/inference_invocations_history.json" +) \ No newline at end of file From c1ee1a25c1563c4941a6551acef13b60d3700d67 Mon Sep 17 00:00:00 2001 From: Ryan Nguyen Date: Sun, 16 Feb 2025 12:16:11 -0500 Subject: [PATCH 14/14] lint --- src/rank_llm/demo/rerank_duot5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rank_llm/demo/rerank_duot5.py b/src/rank_llm/demo/rerank_duot5.py index b48d6ebe..4c89013c 100644 --- a/src/rank_llm/demo/rerank_duot5.py +++ b/src/rank_llm/demo/rerank_duot5.py @@ -27,4 +27,4 @@ writer.write_in_trec_eval_format(f"demo_outputs/rerank_results.txt") writer.write_inference_invocations_history( f"demo_outputs/inference_invocations_history.json" -) \ No newline at end of file +)
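Taken together, the pairwise path added in these patches enumerates each unordered candidate pair (i, j) with i < j once, credits `score` to candidate i and `1 - score` to candidate j, and then sorts candidates by their accumulated totals (the comparator is inverted so a plain sort leaves the highest-scoring candidates first). The following is a minimal standalone sketch of that aggregation, with made-up probabilities and variable names, not part of the patch:

# Illustrative sketch (not part of the patch): aggregating pairwise probabilities into a ranking.
# p[(i, j)] stands in for P(candidate i is more relevant than candidate j); values are invented.
p = {(0, 1): 0.9, (0, 2): 0.7, (1, 2): 0.4}

totals = {i: 0.0 for i in range(3)}
for (i, j), score in p.items():
    totals[i] += score        # credit for candidate i
    totals[j] += 1.0 - score  # complementary credit for candidate j

ranking = sorted(totals, key=totals.get, reverse=True)
print(totals)   # {0: 1.6, 1: 0.5, 2: 0.9}
print(ranking)  # [0, 2, 1]

Because every comparison hands out exactly one unit of credit split between the two candidates, each candidate's total stays on a common scale and the final descending sort gives the reranked order.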