
Gemini API with RankLLM #162

Closed · wants to merge 4 commits into from

Conversation

@natek-1 (Contributor) commented Jan 9, 2025

Pull Request Checklist

Reference Issue

Please provide a reference to the issue this PR addresses (# followed by the issue number). If there is no associated issue, write "N/A".

ref:

Checklist Items

Before submitting your pull request, please review these items:

  • Have you followed the contributing guidelines?
  • Have you verified that there are no existing Pull Requests for the same update/change?
  • Have you updated any relevant documentation or added new tests where needed?

PR Type

What kind of change does this PR introduce?

  • Bugfix
  • Feature
  • Code style update (formatting, local variables)
  • Refactoring (no functional changes, no api changes)
  • Documentation content changes
  • Other...
Description

Implemented Gemini API models (e.g., gemini-2.0-flash-exp) in the generation component of RankLLM.

Dependencies

Aside from the normal RankLLM setup, run:

pip install -U -q "google-generativeai>=0.8.2"

Example command

To use the added features, run the following command (or a similar one):

python src/rank_llm/scripts/run_rank_llm.py  --model_path=gemini-2.0-flash-exp --top_k_candidates=20 --dataset=dl19   --retrieval_method=bm25 --prompt_mode=GEMINI --context_size=8192

@Yuv-sue1005 left a comment:

Good work! Just a couple of things that could be a little more concise and less redundant. Take a look and change as you see fit.

@@ -32,3 +32,13 @@ def get_azure_openai_args() -> Dict[str, str]:
        list(azure_args.values())
    ), "Ensure that `AZURE_OPENAI_API_BASE`, `AZURE_OPENAI_API_VERSION` are set"
    return azure_args
+
+# edit required here to get gemini keys
+# add function to get gemini key similar to get openai_api_key()


Should remove these comments, not necessary anymore
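For reference, a minimal sketch of such a helper, mirroring the existing OpenAI/Azure key helpers (the environment-variable name GEMINI_API_KEY is an assumption, not something confirmed by this PR):

import os


def get_gemini_api_key() -> str:
    # Assumes the key is exposed via the GEMINI_API_KEY environment
    # variable, analogous to get_openai_api_key(); adjust the name if
    # the project settles on a different convention.
    api_key = os.getenv("GEMINI_API_KEY")
    assert api_key, "Ensure that `GEMINI_API_KEY` is set"
    return api_key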

@@ -2,10 +2,12 @@
from .rank_listwise_os_llm import RankListwiseOSLLM
from .vicuna_reranker import VicunaReranker
from .zephyr_reranker import ZephyrReranker
from .rank_gemini import SafeGemini


Just for clarity when future users try using Gemini, it might be a good idea to rename this class to "GeminiReranker". SafeOpenai was named as such because of the external packages it uses (the main one being called SafeOpenai), so it wouldn't make sense to apply that naming to Gemini.

psg_ids.append(psg_id)
message += f'QUESTION = "{query}"\n'
message += "PASSAGES = [" + ", ".join(psg_ids) + "]\n"
message += "SORTED_PASSAGES = [\n"


would it be possible to combine these 3 += for message into just 1 line? Just to be concise.
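One way to collapse them into a single statement, as a sketch (variable names taken from the excerpt above):

message += (
    f'QUESTION = "{query}"\n'
    f'PASSAGES = [{", ".join(psg_ids)}]\n'
    "SORTED_PASSAGES = [\n"
)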

while True:
    try:
        if completion_mode == self.CompletionMode.CHAT:
            model = genai.GenerativeModel(


Depending on if this function is called multiple times or not, might be a good idea to configure the model in init and save it as self.model (or something similar). Then, the model is only configured and loaded once for the entire class, instead of loaded again every time this function is called. Could save a lot of time.
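A minimal sketch of that refactor, assuming the class configures the SDK and stores the model name in its constructor (the class and attribute names here are illustrative, not taken from the PR):

import google.generativeai as genai


class GeminiReranker:
    def __init__(self, model_name: str, api_key: str):
        genai.configure(api_key=api_key)
        self.model_name = model_name
        # Build the model once; _call_completion and get_num_tokens
        # can then reuse self._genai_model instead of constructing a
        # new genai.GenerativeModel on every call.
        self._genai_model = genai.GenerativeModel(model_name)

    def get_num_tokens(self, prompt: str) -> int:
        # Reuses the model created in __init__.
        return self._genai_model.count_tokens(prompt).total_tokens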

    return_text=True,
    **{model_key: self._model},
)
token_counter = genai.GenerativeModel(self.model_name)


Same comment here as for the _call_completion function; this should probably be done in the initializer.

def get_num_tokens(self, prompt: Union[str, List[Dict[str, str]]]) -> int:
    """Returns the number of tokens used by a list of messages in prompt."""
    num_tokens = 0
    model = genai.GenerativeModel(self.model_name)


Same comment here as with the _call_completion function.

{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, # ask how to block civic integrity and add here


not sure myself on the civic integrity thing, might want to ask Ronak.
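For what it's worth, recent Gemini API releases document a civic-integrity harm category. A hedged sketch of the full list (verify the category name against the installed google-generativeai version before relying on it):

safety_settings = [
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    # Assumption: HARM_CATEGORY_CIVIC_INTEGRITY is supported by the
    # installed SDK version; older releases will reject this entry.
    {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]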

}
]
}
#message = [{"role": "user", "content": message}]


Can remove this comment, unless it might be useful in the future.

else:
    response = model.count_tokens(prompt).total_tokens
    num_tokens += response
    # num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>; check later for how to approach this issue


Don't worry too much about this function. I don't think the values it returns are used anywhere as of right now. If you want, message me later about this just to clarify what exactly the problem is.

@@ -2,10 +2,12 @@
from .rank_listwise_os_llm import RankListwiseOSLLM
from .vicuna_reranker import VicunaReranker
from .zephyr_reranker import ZephyrReranker
from .rank_gemini import SafeGemini

Suggested change (from a contributor):
-from .rank_gemini import SafeGemini
+try:
+    from .rank_gemini import SafeGemini
+except ImportError:
+    SafeGemini = None


If I run RankZephyr normally without installing the Gemini dependencies, then I will get an ImportError. I think it is a good idea to use try/except so that we don't force everyone to download the Gemini dependencies.

@@ -15,6 +15,7 @@ class PromptMode(Enum):
    LRL = "LRL"
    MONOT5 = "monot5"
    LiT5 = "LiT5"
+    GEMINI = "GEMINI"
A member commented:

What about this prompt makes it GEMINI-specific?
Please implement either the RANK_GPT or RANK_GPT_APEER prompts with the Gemini reranker.
Ideally, we don't want to have too many prompt modes. MONOT5 and LiT5 are included here because they are T5 models trained with specific prompts and won't work with the RANK_GPT or RANK_GPT_APEER prompts.
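Under that suggestion, the earlier example command would change only in its prompt mode, roughly as follows (a sketch; the exact --prompt_mode string depends on the PromptMode enum values in the repo):

python src/rank_llm/scripts/run_rank_llm.py --model_path=gemini-2.0-flash-exp --top_k_candidates=20 --dataset=dl19 --retrieval_method=bm25 --prompt_mode=rank_GPT --context_size=8192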

@sahel-sh mentioned this pull request Feb 10, 2025
@sahel-sh (Member) commented:

Commits from this PR were merged as part of PR #176.

@sahel-sh closed this Feb 10, 2025
@sahel-sh (Member) commented:
@natek-1 thank you for drafting this PR. I made sure your commits are pulled into the history of PR #176 so that your contribution is not lost!
