Gemini API with RankLLM #162
Conversation
Good work! Just a couple of things that could be a little more concise and less redundant. Take a look and change as you see fit.
src/rank_llm/rerank/api_keys.py
```python
@@ -32,3 +32,13 @@ def get_azure_openai_args() -> Dict[str, str]:
        list(azure_args.values())
    ), "Ensure that `AZURE_OPENAI_API_BASE`, `AZURE_OPENAI_API_VERSION` are set"
    return azure_args

# edit required here to get gemini keys
# add function to get gemini key similar to get openai_api_key()
```
These comments should be removed; they aren't necessary anymore.
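For reference, a minimal sketch of what such a helper could look like, mirroring `get_openai_api_key()`. The environment-variable name `GEMINI_API_KEY` is an assumption, not something specified in this PR:

```python
import os


def get_gemini_api_key() -> str:
    # Hypothetical helper mirroring get_openai_api_key(); the
    # GEMINI_API_KEY variable name is an assumption.
    key = os.getenv("GEMINI_API_KEY")
    assert key, "Ensure that `GEMINI_API_KEY` is set"
    return key
```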
```python
@@ -2,10 +2,12 @@
from .rank_listwise_os_llm import RankListwiseOSLLM
from .vicuna_reranker import VicunaReranker
from .zephyr_reranker import ZephyrReranker
from .rank_gemini import SafeGemini
```
Just for clarity when future users try utilizing Gemini, it might be a good idea to rename this class to "GeminiReranker". SafeOpenai was named that way because of the external packages it uses (the main one being called SafeOpenai), so it wouldn't make sense to apply that naming to Gemini.
```python
    psg_ids.append(psg_id)
message += f'QUESTION = "{query}"\n'
message += "PASSAGES = [" + ", ".join(psg_ids) + "]\n"
message += "SORTED_PASSAGES = [\n"
```
Would it be possible to combine these three `+=` statements for `message` into just one line? Just to be concise.
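One way to fold the three statements into a single `+=`, using adjacent string literals. This is a sketch with stand-in values; in the real code `query` and `psg_ids` come from the surrounding function:

```python
query = "sample query"           # stand-in value for illustration
psg_ids = ["[1]", "[2]", "[3]"]  # stand-in value for illustration
message = ""

# Single += using implicit concatenation of adjacent string literals,
# replacing the three separate += statements.
message += (
    f'QUESTION = "{query}"\n'
    f"PASSAGES = [{', '.join(psg_ids)}]\n"
    "SORTED_PASSAGES = [\n"
)
```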
```python
while True:
    try:
        if completion_mode == self.CompletionMode.CHAT:
            model = genai.GenerativeModel(
```
Depending on whether this function is called multiple times, it might be a good idea to configure the model in `__init__` and save it as `self.model` (or something similar). Then the model is configured and loaded only once for the entire class, instead of being loaded again every time this function is called. Could save a lot of time.
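A rough sketch of the suggested refactor. A stub factory stands in for `genai.GenerativeModel` so the example runs without the `google-generativeai` package; the class and method names are illustrative, not the PR's actual code:

```python
class GeminiRerankerSketch:
    def __init__(self, model_name, model_factory):
        self.model_name = model_name
        # Build the model once here; in the real class this would be
        # something like: self._model = genai.GenerativeModel(model_name)
        self._model = model_factory(model_name)

    def call_completion(self, prompt):
        # Reuse the cached model instead of constructing a new one per call.
        return self._model.generate(prompt)


class StubModel:
    instances = 0  # counts constructions to show the model is built once

    def __init__(self, name):
        StubModel.instances += 1
        self.name = name

    def generate(self, prompt):
        return f"{self.name}:{prompt}"
```

With this layout, repeated calls to `call_completion` never reconstruct the model, which is the point of the review comment.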
```python
    return_text=True,
    **{model_key: self._model},
)
token_counter = genai.GenerativeModel(self.model_name)
```
Same comment here as for the `_call_completion` function; this should probably be done in the initializer.
```python
def get_num_tokens(self, prompt: Union[str, List[Dict[str, str]]]) -> int:
    """Returns the number of tokens used by a list of messages in prompt."""
    num_tokens = 0
    model = genai.GenerativeModel(self.model_name)
```
Same comment here as for the `_call_completion` function.
```python
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},  # ask how to block civic integrity and add here
```
Not sure myself about the civic integrity setting; might want to ask Ronak.
```python
    }
]
}
# message = [{"role": "user", "content": message}]
```
This comment can be removed, unless it might be useful in the future.
```python
else:
    response = model.count_tokens(prompt).total_tokens
    num_tokens += response
# num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>; check later for how to approach this issue
```
Don't worry too much about this function; I don't think the values it returns are used anywhere as of right now. If you want, message me later about this just to clarify what exactly the problem is.
```python
@@ -2,10 +2,12 @@
from .rank_listwise_os_llm import RankListwiseOSLLM
from .vicuna_reranker import VicunaReranker
from .zephyr_reranker import ZephyrReranker
from .rank_gemini import SafeGemini
```
Suggested change:

```python
try:
    from .rank_gemini import SafeGemini
except ImportError:
    SafeGemini = None
```
If I run RankZephyr normally without installing the Gemini dependencies, I get an ImportError. I think it's a good idea to use try/except here so that we don't force everyone to download the Gemini dependencies.
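With the guarded import in place, downstream code can check for the optional dependency instead of crashing. A small sketch; the import path is assumed from the diff and may differ:

```python
try:
    from rank_llm.rerank import SafeGemini  # path assumed from the diff
except ImportError:
    SafeGemini = None  # Gemini dependencies not installed


def gemini_available() -> bool:
    # Callers can gate Gemini-specific code paths on this check instead
    # of letting an ImportError propagate at import time.
    return SafeGemini is not None
```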
```python
@@ -15,6 +15,7 @@ class PromptMode(Enum):
    LRL = "LRL"
    MONOT5 = "monot5"
    LiT5 = "LiT5"
    GEMINI = "GEMINI"
```
What about this prompt makes it GEMINI-specific?
Please implement either the RANK_GPT or RANK_GPT_APEER prompts with the Gemini reranker.
Ideally, we don't want to have too many prompt modes; MONOT5 and LiT5 are included here because they are T5 models trained with specific prompts and won't work with the RANK_GPT or RANK_GPT_APEER prompts.
Commits from this PR were merged as part of PR #176.
Pull Request Checklist
Reference Issue
Please provide the reference to issue this PR is addressing (# followed by the issue number). If there is no associated issue, write "N/A".
ref:
Checklist Items
Before submitting your pull request, please review these items:
PR Type
What kind of change does this PR introduce?
Dependencies
Aside from the normal RankLLM setup, run:

```shell
pip install -U -q "google-generativeai>=0.8.2"
```
Example command
To use the added features, you should run the following (or a similar) command: