|
1 | | -"""LangChain implementation of AIProvider for LaunchDarkly AI SDK.""" |
| 1 | +"""LangChain provider module for LaunchDarkly AI SDK.""" |
2 | 2 |
|
3 | | -from typing import Any, Dict, List, Optional |
| 3 | +from ldai.providers.langchain.langchain_provider import LangChainProvider |
4 | 4 |
|
5 | | -from langchain_core.language_models.chat_models import BaseChatModel |
6 | | -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage |
7 | | - |
8 | | -from ldai.models import AIConfigKind, LDMessage |
9 | | -from ldai.providers.ai_provider import AIProvider |
10 | | -from ldai.providers.types import ChatResponse, LDAIMetrics, StructuredResponse |
11 | | -from ldai.tracker import TokenUsage |
12 | | - |
13 | | - |
14 | | -class LangChainProvider(AIProvider): |
15 | | - """ |
16 | | - LangChain implementation of AIProvider. |
17 | | - |
18 | | - This provider integrates LangChain models with LaunchDarkly's tracking capabilities. |
19 | | - """ |
20 | | - |
21 | | - def __init__(self, llm: BaseChatModel, logger: Optional[Any] = None): |
22 | | - """ |
23 | | - Initialize the LangChain provider. |
24 | | - |
25 | | - :param llm: LangChain BaseChatModel instance |
26 | | - :param logger: Optional logger for logging provider operations |
27 | | - """ |
28 | | - super().__init__(logger) |
29 | | - self._llm = llm |
30 | | - |
31 | | - # ============================================================================= |
32 | | - # MAIN FACTORY METHOD |
33 | | - # ============================================================================= |
34 | | - |
35 | | - @staticmethod |
36 | | - async def create(ai_config: AIConfigKind, logger: Optional[Any] = None) -> 'LangChainProvider': |
37 | | - """ |
38 | | - Static factory method to create a LangChain AIProvider from an AI configuration. |
39 | | - |
40 | | - :param ai_config: The LaunchDarkly AI configuration |
41 | | - :param logger: Optional logger for the provider |
42 | | - :return: Configured LangChainProvider instance |
43 | | - """ |
44 | | - llm = await LangChainProvider.create_langchain_model(ai_config) |
45 | | - return LangChainProvider(llm, logger) |
46 | | - |
47 | | - # ============================================================================= |
48 | | - # INSTANCE METHODS (AIProvider Implementation) |
49 | | - # ============================================================================= |
50 | | - |
51 | | - async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse: |
52 | | - """ |
53 | | - Invoke the LangChain model with an array of messages. |
54 | | - |
55 | | - :param messages: Array of LDMessage objects representing the conversation |
56 | | - :return: ChatResponse containing the model's response |
57 | | - """ |
58 | | - try: |
59 | | - # Convert LDMessage[] to LangChain messages |
60 | | - langchain_messages = LangChainProvider.convert_messages_to_langchain(messages) |
61 | | - |
62 | | - # Get the LangChain response |
63 | | - response: AIMessage = await self._llm.ainvoke(langchain_messages) |
64 | | - |
65 | | - # Generate metrics early (assumes success by default) |
66 | | - metrics = LangChainProvider.get_ai_metrics_from_response(response) |
67 | | - |
68 | | - # Extract text content from the response |
69 | | - content: str = '' |
70 | | - if isinstance(response.content, str): |
71 | | - content = response.content |
72 | | - else: |
73 | | - # Log warning for non-string content (likely multimodal) |
74 | | - if self.logger: |
75 | | - self.logger.warn( |
76 | | - f"Multimodal response not supported, expecting a string. " |
77 | | - f"Content type: {type(response.content)}, Content: {response.content}" |
78 | | - ) |
79 | | - # Update metrics to reflect content loss |
80 | | - metrics.success = False |
81 | | - |
82 | | - # Create the assistant message |
83 | | - from ldai.models import LDMessage |
84 | | - assistant_message = LDMessage(role='assistant', content=content) |
85 | | - |
86 | | - return ChatResponse( |
87 | | - message=assistant_message, |
88 | | - metrics=metrics, |
89 | | - ) |
90 | | - except Exception as error: |
91 | | - if self.logger: |
92 | | - self.logger.warn(f'LangChain model invocation failed: {error}') |
93 | | - |
94 | | - from ldai.models import LDMessage |
95 | | - return ChatResponse( |
96 | | - message=LDMessage(role='assistant', content=''), |
97 | | - metrics=LDAIMetrics(success=False, usage=None), |
98 | | - ) |
99 | | - |
100 | | - async def invoke_structured_model( |
101 | | - self, |
102 | | - messages: List[LDMessage], |
103 | | - response_structure: Dict[str, Any], |
104 | | - ) -> StructuredResponse: |
105 | | - """ |
106 | | - Invoke the LangChain model with structured output support. |
107 | | - |
108 | | - :param messages: Array of LDMessage objects representing the conversation |
109 | | - :param response_structure: Dictionary of output configurations keyed by output name |
110 | | - :return: StructuredResponse containing the structured data |
111 | | - """ |
112 | | - try: |
113 | | - # Convert LDMessage[] to LangChain messages |
114 | | - langchain_messages = LangChainProvider.convert_messages_to_langchain(messages) |
115 | | - |
116 | | - # Get the LangChain response with structured output |
117 | | - # Note: with_structured_output is available on BaseChatModel in newer LangChain versions |
118 | | - if hasattr(self._llm, 'with_structured_output'): |
119 | | - structured_llm = self._llm.with_structured_output(response_structure) |
120 | | - response = await structured_llm.ainvoke(langchain_messages) |
121 | | - else: |
122 | | - # Fallback: invoke normally and try to parse as JSON |
123 | | - response_obj = await self._llm.ainvoke(langchain_messages) |
124 | | - if isinstance(response_obj, AIMessage): |
125 | | - import json |
126 | | - try: |
127 | | - response = json.loads(response_obj.content) |
128 | | - except json.JSONDecodeError: |
129 | | - response = {'content': response_obj.content} |
130 | | - else: |
131 | | - response = response_obj |
132 | | - |
133 | | - # Using structured output doesn't support metrics |
134 | | - metrics = LDAIMetrics( |
135 | | - success=True, |
136 | | - usage=TokenUsage(total=0, input=0, output=0), |
137 | | - ) |
138 | | - |
139 | | - import json |
140 | | - return StructuredResponse( |
141 | | - data=response if isinstance(response, dict) else {'result': response}, |
142 | | - raw_response=json.dumps(response) if not isinstance(response, str) else response, |
143 | | - metrics=metrics, |
144 | | - ) |
145 | | - except Exception as error: |
146 | | - if self.logger: |
147 | | - self.logger.warn(f'LangChain structured model invocation failed: {error}') |
148 | | - |
149 | | - return StructuredResponse( |
150 | | - data={}, |
151 | | - raw_response='', |
152 | | - metrics=LDAIMetrics( |
153 | | - success=False, |
154 | | - usage=TokenUsage(total=0, input=0, output=0), |
155 | | - ), |
156 | | - ) |
157 | | - |
158 | | - def get_chat_model(self) -> BaseChatModel: |
159 | | - """ |
160 | | - Get the underlying LangChain model instance. |
161 | | - |
162 | | - :return: The LangChain BaseChatModel instance |
163 | | - """ |
164 | | - return self._llm |
165 | | - |
166 | | - # ============================================================================= |
167 | | - # STATIC UTILITY METHODS |
168 | | - # ============================================================================= |
169 | | - |
170 | | - @staticmethod |
171 | | - def map_provider(ld_provider_name: str) -> str: |
172 | | - """ |
173 | | - Map LaunchDarkly provider names to LangChain provider names. |
174 | | - |
175 | | - This method enables seamless integration between LaunchDarkly's standardized |
176 | | - provider naming and LangChain's naming conventions. |
177 | | - |
178 | | - :param ld_provider_name: LaunchDarkly provider name |
179 | | - :return: LangChain provider name |
180 | | - """ |
181 | | - lowercased_name = ld_provider_name.lower() |
182 | | - |
183 | | - mapping: Dict[str, str] = { |
184 | | - 'gemini': 'google-genai', |
185 | | - } |
186 | | - |
187 | | - return mapping.get(lowercased_name, lowercased_name) |
188 | | - |
189 | | - @staticmethod |
190 | | - def get_ai_metrics_from_response(response: AIMessage) -> LDAIMetrics: |
191 | | - """ |
192 | | - Get AI metrics from a LangChain provider response. |
193 | | - |
194 | | - This method extracts token usage information and success status from LangChain responses |
195 | | - and returns a LaunchDarkly LDAIMetrics object. |
196 | | - |
197 | | - :param response: The response from the LangChain model |
198 | | - :return: LDAIMetrics with success status and token usage |
199 | | - """ |
200 | | - # Extract token usage if available |
201 | | - usage: Optional[TokenUsage] = None |
202 | | - if hasattr(response, 'response_metadata') and response.response_metadata: |
203 | | - token_usage = response.response_metadata.get('token_usage') |
204 | | - if token_usage: |
205 | | - usage = TokenUsage( |
206 | | - total=token_usage.get('total_tokens', 0) or token_usage.get('totalTokens', 0) or 0, |
207 | | - input=token_usage.get('prompt_tokens', 0) or token_usage.get('promptTokens', 0) or 0, |
208 | | - output=token_usage.get('completion_tokens', 0) or token_usage.get('completionTokens', 0) or 0, |
209 | | - ) |
210 | | - |
211 | | - # LangChain responses that complete successfully are considered successful by default |
212 | | - return LDAIMetrics(success=True, usage=usage) |
213 | | - |
214 | | - @staticmethod |
215 | | - def convert_messages_to_langchain(messages: List[LDMessage]) -> List[BaseMessage]: |
216 | | - """ |
217 | | - Convert LaunchDarkly messages to LangChain messages. |
218 | | - |
219 | | - This helper method enables developers to work directly with LangChain message types |
220 | | - while maintaining compatibility with LaunchDarkly's standardized message format. |
221 | | - |
222 | | - :param messages: List of LDMessage objects |
223 | | - :return: List of LangChain message objects |
224 | | - """ |
225 | | - result: List[BaseMessage] = [] |
226 | | - for msg in messages: |
227 | | - if msg.role == 'system': |
228 | | - result.append(SystemMessage(content=msg.content)) |
229 | | - elif msg.role == 'user': |
230 | | - result.append(HumanMessage(content=msg.content)) |
231 | | - elif msg.role == 'assistant': |
232 | | - result.append(AIMessage(content=msg.content)) |
233 | | - else: |
234 | | - raise ValueError(f'Unsupported message role: {msg.role}') |
235 | | - return result |
236 | | - |
237 | | - @staticmethod |
238 | | - async def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel: |
239 | | - """ |
240 | | - Create a LangChain model from an AI configuration. |
241 | | - |
242 | | - This public helper method enables developers to initialize their own LangChain models |
243 | | - using LaunchDarkly AI configurations. |
244 | | - |
245 | | - :param ai_config: The LaunchDarkly AI configuration |
246 | | - :return: A configured LangChain BaseChatModel |
247 | | - """ |
248 | | - model_name = ai_config.model.name if ai_config.model else '' |
249 | | - provider = ai_config.provider.name if ai_config.provider else '' |
250 | | - parameters = ai_config.model.get_parameter('parameters') if ai_config.model else {} |
251 | | - if not isinstance(parameters, dict): |
252 | | - parameters = {} |
253 | | - |
254 | | - # Use LangChain's init_chat_model to support multiple providers |
255 | | - # Note: This requires langchain package to be installed |
256 | | - try: |
257 | | - # Try to import init_chat_model from langchain.chat_models |
258 | | - # This is available in langchain >= 0.1.0 |
259 | | - try: |
260 | | - from langchain.chat_models import init_chat_model |
261 | | - except ImportError: |
262 | | - # Fallback for older versions or different import path |
263 | | - from langchain.chat_models.universal import init_chat_model |
264 | | - |
265 | | - # Map provider name |
266 | | - langchain_provider = LangChainProvider.map_provider(provider) |
267 | | - |
268 | | - # Create model configuration |
269 | | - model_kwargs = {**parameters} |
270 | | - if langchain_provider: |
271 | | - model_kwargs['model_provider'] = langchain_provider |
272 | | - |
273 | | - # Initialize the chat model (init_chat_model may be async or sync) |
274 | | - result = init_chat_model(model_name, **model_kwargs) |
275 | | - # Handle both sync and async initialization |
276 | | - if hasattr(result, '__await__'): |
277 | | - return await result |
278 | | - return result |
279 | | - except ImportError as e: |
280 | | - raise ImportError( |
281 | | - 'langchain package is required for LangChainProvider. ' |
282 | | - 'Install it with: pip install langchain langchain-core' |
283 | | - ) from e |
| 5 | +__all__ = ['LangChainProvider'] |
284 | 6 |
|
0 commit comments