Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions nemoguardrails/llm/models/langchain_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _init_chat_completion_model(model_name: str, provider_name: str, kwargs: Dic
raise


def _init_text_completion_model(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM:
def _init_text_completion_model(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM | None:
"""Initialize a text completion model.

Args:
Expand All @@ -234,22 +234,24 @@ def _init_text_completion_model(model_name: str, provider_name: str, kwargs: Dic
kwargs: Additional arguments to pass to the model initialization

Returns:
An initialized text completion model

Raises:
RuntimeError: If the provider is not found
An initialized text completion model, or None if the provider is not found
"""
provider_cls = _get_text_completion_provider(provider_name)
try:
provider_cls = _get_text_completion_provider(provider_name)
except RuntimeError:
return None

if provider_cls is None:
raise ValueError()
return None

kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
# remove stream_usage parameter as it's not supported by text completion APIs
# (e.g., OpenAI's AsyncCompletions.create() doesn't accept this parameter)
kwargs.pop("stream_usage", None)
return provider_cls(**kwargs)


def _init_community_chat_models(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseChatModel:
def _init_community_chat_models(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseChatModel | None:
"""Initialize community chat models.

Args:
Expand All @@ -264,14 +266,19 @@ def _init_community_chat_models(model_name: str, provider_name: str, kwargs: Dic
ImportError: If langchain_community is not installed
ModelInitializationError: If model initialization fails
"""
provider_cls = _get_chat_completion_provider(provider_name)
try:
provider_cls = _get_chat_completion_provider(provider_name)
except RuntimeError:
return None

if provider_cls is None:
raise ValueError()
return None

kwargs = _update_model_kwargs(provider_cls, model_name, kwargs)
return provider_cls(**kwargs)


def _init_gpt35_turbo_instruct(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM:
def _init_gpt35_turbo_instruct(model_name: str, provider_name: str, kwargs: Dict[str, Any]) -> BaseLLM | None:
"""Initialize GPT-3.5 Turbo Instruct model.

Currently init_chat_model from langchain infers this as a chat model.
Expand Down
6 changes: 2 additions & 4 deletions tests/llm_providers/test_langchain_initialization_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def test_init_community_chat_models_no_provider(self):
"nemoguardrails.llm.models.langchain_initializer._get_chat_completion_provider"
) as mock_get_provider:
mock_get_provider.return_value = None
with pytest.raises(ValueError):
_init_community_chat_models("community-model", "provider", {})
assert _init_community_chat_models("community-model", "provider", {}) is None


class TestTextCompletionInitializer:
Expand Down Expand Up @@ -156,8 +155,7 @@ def test_init_text_completion_model_no_provider(self):
"nemoguardrails.llm.models.langchain_initializer._get_text_completion_provider"
) as mock_get_provider:
mock_get_provider.return_value = None
with pytest.raises(ValueError):
_init_text_completion_model("text-model", "provider", {})
assert _init_text_completion_model("text-model", "provider", {}) is None


class TestUpdateModelKwargs:
Expand Down
26 changes: 26 additions & 0 deletions tests/llm_providers/test_langchain_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,29 @@ def test_text_completion_supports_chat_mode(mock_initializers):
mock_initializers["chat"].assert_called_once()
mock_initializers["community"].assert_called_once()
mock_initializers["text"].assert_called_once()


def test_exception_not_masked_by_none_return(mock_initializers):
"""Test that an exception from an initializer is preserved when later ones return None.

For example: if community chat throws an error (e.g., invalid API key), but text completion
returns None because that provider type doesn't exist, the community error should be raised.
"""
mock_initializers["special"].return_value = None
mock_initializers["chat"].return_value = None
mock_initializers["community"].side_effect = ValueError("Invalid API key for provider")
mock_initializers["text"].return_value = None # Provider not found, returns None

with pytest.raises(ModelInitializationError, match="Invalid API key for provider"):
init_langchain_model("community-model", "provider", "chat", {})


def test_import_error_prioritized_over_other_exceptions(mock_initializers):
"""Test that ImportError is surfaced to help users know when packages are missing."""
mock_initializers["special"].return_value = None
mock_initializers["chat"].side_effect = ValueError("Some config error")
mock_initializers["community"].side_effect = ImportError("Missing langchain_community package")
mock_initializers["text"].return_value = None

with pytest.raises(ModelInitializationError, match="Missing langchain_community package"):
init_langchain_model("model", "provider", "chat", {})