|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 | import contextvars |
17 | | -from typing import Optional |
| 17 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union |
18 | 18 |
|
19 | | -streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None) |
| 19 | +from nemoguardrails.logging.explain import LLMCallInfo |
| 20 | + |
| 21 | +if TYPE_CHECKING: |
| 22 | + from nemoguardrails.logging.explain import ExplainInfo |
| 23 | + from nemoguardrails.logging.stats import LLMStats |
| 24 | + from nemoguardrails.rails.llm.options import GenerationOptions |
| 25 | + from nemoguardrails.streaming import StreamingHandler |
| 26 | + |
| 27 | +streaming_handler_var: contextvars.ContextVar[ |
| 28 | + Optional["StreamingHandler"] |
| 29 | +] = contextvars.ContextVar("streaming_handler", default=None) |
20 | 30 |
|
21 | 31 | # The object that holds additional explanation information. |
22 | | -explain_info_var = contextvars.ContextVar("explain_info", default=None) |
| 32 | +explain_info_var: contextvars.ContextVar[ |
| 33 | + Optional["ExplainInfo"] |
| 34 | +] = contextvars.ContextVar("explain_info", default=None) |
23 | 35 |
|
24 | 36 | # The current LLM call. |
25 | | -llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None) |
| 37 | +llm_call_info_var: contextvars.ContextVar[ |
| 38 | + Optional[LLMCallInfo] |
| 39 | +] = contextvars.ContextVar("llm_call_info", default=None) |
26 | 40 |
|
27 | 41 | # All the generation options applicable to the current context. |
28 | | -generation_options_var = contextvars.ContextVar("generation_options", default=None) |
| 42 | +generation_options_var: contextvars.ContextVar[ |
| 43 | + Optional["GenerationOptions"] |
| 44 | +] = contextvars.ContextVar("generation_options", default=None) |
29 | 45 |
|
30 | 46 | # The stats about the LLM calls. |
31 | | -llm_stats_var = contextvars.ContextVar("llm_stats", default=None) |
| 47 | +llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar( |
| 48 | + "llm_stats", default=None |
| 49 | +) |
32 | 50 |
|
33 | 51 | # The raw LLM request that comes from the user. |
34 | 52 | # This is used in passthrough mode. |
35 | | -raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None) |
| 53 | +raw_llm_request: contextvars.ContextVar[ |
| 54 | + Optional[Union[str, List[Dict[str, Any]]]] |
| 55 | +] = contextvars.ContextVar("raw_llm_request", default=None) |
36 | 56 |
|
37 | 57 | reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( |
38 | 58 | "reasoning_trace", default=None |
|
0 commit comments