Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions ragas/src/ragas/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
VertexAI,
]

# Model names whose backing API rejects a `temperature` parameter
# (presumably OpenAI reasoning-model family — confirm against provider docs
# before extending). Checked with `not in` before setting `temperature` on
# the wrapped langchain LLM; a frozenset gives O(1) membership tests and
# prevents accidental mutation of this module-level constant.
MODELS_NOT_SUPPORT_TEMP = frozenset({"o3-mini", "o4-mini", "o3"})


def is_multiple_completion_supported(llm: BaseLanguageModel) -> bool:
"""Return whether the given LLM supports n-completion."""
Expand Down Expand Up @@ -205,8 +207,13 @@ def generate_text(
if temperature is None:
temperature = self.get_temperature(n=n)
if hasattr(self.langchain_llm, "temperature"):
self.langchain_llm.temperature = temperature # type: ignore
old_temperature = temperature
if hasattr(self.langchain_llm, "model_name"):
if self.langchain_llm.model_name not in MODELS_NOT_SUPPORT_TEMP:
self.langchain_llm.temperature = temperature # type: ignore
old_temperature = temperature
else:
self.langchain_llm.temperature = temperature
old_temperature = temperature

if is_multiple_completion_supported(self.langchain_llm):
result = self.langchain_llm.generate_prompt(
Expand Down Expand Up @@ -245,8 +252,13 @@ async def agenerate_text(
if temperature is None:
temperature = self.get_temperature(n=n)
if hasattr(self.langchain_llm, "temperature"):
self.langchain_llm.temperature = temperature # type: ignore
old_temperature = temperature
if hasattr(self.langchain_llm, "model_name"):
if self.langchain_llm.model_name not in MODELS_NOT_SUPPORT_TEMP:
self.langchain_llm.temperature = temperature # type: ignore
old_temperature = temperature
else:
self.langchain_llm.temperature = temperature
old_temperature = temperature

# handle n
if hasattr(self.langchain_llm, "n"):
Expand Down