# --- nemoguardrails/embeddings/providers/__init__.py (patched section) ---

from typing import Optional, Type

# `azure` is imported alongside the other providers so its model class can be
# registered with the provider registry below.
from . import azure, fastembed, nim, openai, sentence_transformers
from .base import EmbeddingModel
from .registry import EmbeddingProviderRegistry

# Add all the implemented embedding providers to the registry.
# As we are not using the `Registered` class, we need to manually register the providers.
register_embedding_provider(azure.AzureOpenAIEmbeddingModel)
register_embedding_provider(fastembed.FastEmbedEmbeddingModel)
register_embedding_provider(openai.OpenAIEmbeddingModel)
register_embedding_provider(sentence_transformers.SentenceTransformerEmbeddingModel)

# --- nemoguardrails/embeddings/providers/azure.py (new file) ---
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextvars import ContextVar
from typing import List, Optional

from .base import EmbeddingModel

# We set the Azure OpenAI async client in an asyncio context variable because we need it
# to be scoped at the asyncio loop level. The client caches it somewhere, and if the loop
# is changed, it will fail.
# NOTE(review): currently unused in this module — `encode_async` falls back to a
# thread-executor implementation (see below). Kept for parity with the other
# providers and for a future native-async implementation.
async_client_var: ContextVar = ContextVar("azure_async_client", default=None)


class AzureOpenAIEmbeddingModel(EmbeddingModel):
    """Embedding model using Azure OpenAI API.

    Args:
        embedding_model (str): The name of the embedding model deployment.
        azure_endpoint (Optional[str]): The Azure OpenAI endpoint URL. When omitted,
            the AzureOpenAI client resolves it from its own defaults
            (e.g. the ``AZURE_OPENAI_ENDPOINT`` environment variable).
        api_version (str): The API version to use (defaults to "2024-02-01").
        **kwargs: Additional arguments passed to the AzureOpenAI client
            (e.g. ``api_key``).

    Attributes:
        model (str): The name of the embedding model deployment.
        embedding_size (int): The size of the embeddings.

    Methods:
        encode: Encode a list of documents into embeddings.
        encode_async: Asynchronously encode a list of documents into embeddings.

    Raises:
        ImportError: If the ``openai`` package is not installed.
        RuntimeError: If an ``openai`` version older than 1.0.0 is installed.
    """

    engine_name = "azure"

    def __init__(
        self,
        embedding_model: str,
        azure_endpoint: Optional[str] = None,
        api_version: str = "2024-02-01",
        **kwargs,
    ):
        # The openai import is deferred so the provider can be registered without
        # the package installed; the error surfaces only on first use.
        try:
            import openai
            from openai import AzureOpenAI
        except ImportError:
            raise ImportError(
                "Could not import openai, please install it with "
                "`pip install openai`."
            )
        # The v1 client API (AzureOpenAI class) is required.
        if openai.__version__ < "1.0.0":
            raise RuntimeError(
                "`openai<1.0.0` is no longer supported. "
                "Please upgrade using `pip install openai>=1.0.0`."
            )

        self.model = embedding_model

        # Assemble the client configuration. The endpoint is only passed when it
        # was explicitly provided, so the client can fall back to its env-var
        # based defaults otherwise.
        client_kwargs = {"api_version": api_version, **kwargs}
        if azure_endpoint:
            client_kwargs["azure_endpoint"] = azure_endpoint

        self.client = AzureOpenAI(**client_kwargs)

        # Known sizes for the standard OpenAI embedding models. Azure deployments
        # often reuse these model names as deployment names.
        self.embedding_size_dict = {
            "text-embedding-ada-002": 1536,
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
        }

        if self.model in self.embedding_size_dict:
            self.embedding_size = self.embedding_size_dict[self.model]
        else:
            # Custom deployment name: probe the API once with a tiny request to
            # learn the actual embedding size.
            try:
                self.embedding_size = len(self.encode(["test"])[0])
            except Exception:
                # If the probe fails (e.g. transient API error), fall back to the
                # most common embedding size.
                self.embedding_size = 1536

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.

        """
        # NOTE: A native AsyncAzureOpenAI implementation hit edge cases with
        # httpx and event-loop lifetimes ("Event loop is closed." errors), so we
        # run the synchronous `encode` in the default thread-pool executor instead.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(None, self.encode, documents)

        return embeddings

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings, one vector per document,
            in the same order as the input.

        """

        # Make embedding request to Azure OpenAI API
        res = self.client.embeddings.create(input=documents, model=self.model)
        embeddings = [record.embedding for record in res.data]

        return embeddings

# --- tests/test_azure_embedding_provider.py (new file) ---
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Azure OpenAI embedding provider.

All tests mock the ``openai.AzureOpenAI`` client, so no network access or
real credentials are required. The patch target works because the provider
imports the client lazily inside ``__init__`` (after the patch is active).
"""

from unittest.mock import Mock, patch

import pytest

from nemoguardrails.embeddings.providers import init_embedding_model
from nemoguardrails.embeddings.providers.registry import \
    EmbeddingProviderRegistry


def test_azure_embedding_provider_registration():
    """Test that the Azure embedding provider is properly registered."""
    registry = EmbeddingProviderRegistry()

    # Check that azure is in the registry
    # NOTE(review): assumes `registry.items` is keyed by engine name — confirm
    # against the EmbeddingProviderRegistry implementation.
    assert "azure" in registry.items

    # Check that we can get the provider class
    provider_class = registry.get("azure")
    assert provider_class is not None


@patch("openai.AzureOpenAI")
def test_azure_embedding_model_initialization(mock_azure_openai):
    """Test Azure embedding model initialization with different parameters."""
    # Mock the AzureOpenAI client
    mock_client = Mock()
    mock_azure_openai.return_value = mock_client

    # Mock the response for size detection
    mock_response = Mock()
    mock_response.data = [Mock()]
    mock_response.data[0].embedding = [0.1] * 1536
    mock_client.embeddings.create.return_value = mock_response

    # Test basic initialization
    model = init_embedding_model(
        embedding_model="text-embedding-ada-002",
        embedding_engine="azure",
        embedding_params={
            "azure_endpoint": "https://example.openai.azure.com/",
            "api_key": "test-key",
        },
    )

    # "text-embedding-ada-002" is a known model, so its size comes from the
    # provider's lookup table rather than an API probe.
    assert model.model == "text-embedding-ada-002"
    assert model.embedding_size == 1536

    # Verify AzureOpenAI was called with correct parameters
    mock_azure_openai.assert_called_with(
        api_version="2024-02-01",
        azure_endpoint="https://example.openai.azure.com/",
        api_key="test-key",
    )


@patch("openai.AzureOpenAI")
def test_azure_embedding_model_custom_deployment(mock_azure_openai):
    """Test Azure embedding model with custom deployment name."""
    # Mock the AzureOpenAI client
    mock_client = Mock()
    mock_azure_openai.return_value = mock_client

    # Mock the response for custom deployment
    mock_response = Mock()
    mock_response.data = [Mock()]
    mock_response.data[0].embedding = [0.1] * 3072  # Different size
    mock_client.embeddings.create.return_value = mock_response

    # Test with custom deployment name (not in the known models dict)
    model = init_embedding_model(
        embedding_model="my-custom-embedding-deployment",
        embedding_engine="azure",
        embedding_params={
            "azure_endpoint": "https://example.openai.azure.com/",
            "api_key": "test-key",
            "api_version": "2023-12-01-preview",
        },
    )

    # The size must come from the probe response above (3072), proving the
    # provider queried the API instead of using a hard-coded value.
    assert model.model == "my-custom-embedding-deployment"
    assert model.embedding_size == 3072

    # Verify the test call was made to determine embedding size
    mock_client.embeddings.create.assert_called()


@patch("openai.AzureOpenAI")
def test_azure_embedding_encode_method(mock_azure_openai):
    """Test the encode method of Azure embedding model."""
    # Mock the AzureOpenAI client and response
    mock_client = Mock()
    mock_azure_openai.return_value = mock_client

    # Mock embeddings response for the actual encoding call
    mock_response = Mock()
    mock_response.data = [
        Mock(embedding=[0.1, 0.2, 0.3]),
        Mock(embedding=[0.4, 0.5, 0.6]),
    ]
    mock_client.embeddings.create.return_value = mock_response

    model = init_embedding_model(
        embedding_model="text-embedding-encode-test",  # Use unique name to avoid cache
        embedding_engine="azure",
        embedding_params={
            "azure_endpoint": "https://example.openai.azure.com/",
            "api_key": "test-key",
        },
    )

    # Test encoding
    documents = ["Hello world", "Test document"]
    embeddings = model.encode(documents)

    assert len(embeddings) == 2
    assert embeddings[0] == [0.1, 0.2, 0.3]
    assert embeddings[1] == [0.4, 0.5, 0.6]

    # Verify the API calls were made correctly
    # Since this is an unknown model, there will be 2 calls: 1 for size detection, 1 for actual encoding
    assert mock_client.embeddings.create.call_count == 2
    # The final call should be for our documents
    mock_client.embeddings.create.assert_called_with(
        input=documents, model="text-embedding-encode-test"
    )


def test_azure_embedding_missing_openai_import():
    """Test proper error handling when openai is not installed."""
    # Mock the import to raise ImportError
    # Setting sys.modules["openai"] = None makes `import openai` raise
    # ImportError; the provider's lazy import inside __init__ then re-raises
    # with the "Could not import openai" message matched below.
    with patch.dict("sys.modules", {"openai": None}):
        with pytest.raises(ImportError, match="Could not import openai"):
            from nemoguardrails.embeddings.providers.azure import \
                AzureOpenAIEmbeddingModel

            AzureOpenAIEmbeddingModel(
                embedding_model="test-model",
                azure_endpoint="https://example.openai.azure.com/",
            )


def test_azure_embedding_provider_in_supported_list():
    """Test that Azure is now in the list of supported embedding engines."""
    from nemoguardrails.embeddings.providers import EmbeddingProviderRegistry

    registry = EmbeddingProviderRegistry()
    supported_engines = registry.list()

    assert "azure" in supported_engines