|
15 | 15 |
|
16 | 16 | import os
|
17 | 17 | from typing import Any, Dict, List, Optional, Union
|
18 |
| -from unittest.mock import patch |
| 18 | +from unittest.mock import MagicMock, patch |
19 | 19 |
|
20 | 20 | import pytest
|
| 21 | +from langchain_core.language_models import BaseChatModel |
21 | 22 |
|
22 | 23 | from nemoguardrails import LLMRails, RailsConfig
|
| 24 | +from nemoguardrails.logging.explain import ExplainInfo |
23 | 25 | from nemoguardrails.rails.llm.config import Model
|
24 | 26 | from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id
|
25 | 27 | from tests.utils import FakeLLM, clean_events, event_sequence_conforms
|
@@ -1170,3 +1172,18 @@ def dummy_parser(text):
|
1170 | 1172 | assert "chained_action" in rails.runtime.action_dispatcher.registered_actions
|
1171 | 1173 | assert "chained_param" in rails.runtime.registered_action_params
|
1172 | 1174 | assert rails.runtime.registered_action_params["chained_param"] == "param_value"
|
| 1175 | + |
| 1176 | + |
| 1177 | +def test_explain_calls_ensure_explain_info(): |
| 1178 | + """Make sure if no `explain_info` attribute is present in LLMRails it's populated with |
| 1179 | + an empty ExplainInfo object""" |
| 1180 | + |
| 1181 | + mock_llm = MagicMock(spec=BaseChatModel) |
| 1182 | + config = RailsConfig.from_content(config={"models": []}) |
| 1183 | + rails = LLMRails(config=config, llm=mock_llm) |
| 1184 | + rails.generate(messages=[{"role": "user", "content": "Hi!"}]) |
| 1185 | + |
| 1186 | + rails.explain_info = None |
| 1187 | + info = rails.explain() |
| 1188 | + assert info == ExplainInfo() |
| 1189 | + assert rails.explain_info == ExplainInfo() |
0 commit comments