Skip to content

Commit ca55708

Browse files
committed
test(cli): add comprehensive CLI test suite and reorganize files
1 parent 949e422 commit ca55708

File tree

7 files changed

+1870
-316
lines changed

7 files changed

+1870
-316
lines changed

tests/cli/test_chat.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
import sys
18+
from unittest.mock import AsyncMock, MagicMock, patch
19+
20+
import pytest
21+
22+
from nemoguardrails.cli.chat import (
23+
ChatState,
24+
extract_scene_text_content,
25+
parse_events_inputs,
26+
run_chat,
27+
)
28+
29+
chat_module = sys.modules["nemoguardrails.cli.chat"]
30+
31+
32+
class TestParseEventsInputs:
33+
def test_parse_simple_event(self):
34+
result = parse_events_inputs("UserAction")
35+
assert result == {"type": "UserAction"}
36+
37+
def test_parse_event_with_params(self):
38+
result = parse_events_inputs('UserAction(name="test", value=123)')
39+
assert result == {"type": "UserAction", "name": "test", "value": 123}
40+
41+
def test_parse_event_with_string_params(self):
42+
result = parse_events_inputs('UserAction(message="hello world")')
43+
assert result == {"type": "UserAction", "message": "hello world"}
44+
45+
def test_parse_nested_event(self):
46+
result = parse_events_inputs("bot.UtteranceAction")
47+
assert result == {"type": "botUtteranceAction"}
48+
49+
def test_parse_event_with_nested_params(self):
50+
result = parse_events_inputs('UserAction(data={"key": "value"})')
51+
assert result == {"type": "UserAction", "data": {"key": "value"}}
52+
53+
def test_parse_event_with_list_params(self):
54+
result = parse_events_inputs("UserAction(items=[1, 2, 3])")
55+
assert result == {"type": "UserAction", "items": [1, 2, 3]}
56+
57+
def test_parse_invalid_event(self):
58+
result = parse_events_inputs("Invalid.Event.Format.TooMany")
59+
assert result is None
60+
61+
def test_parse_event_missing_equals(self):
62+
result = parse_events_inputs("UserAction(invalid_param)")
63+
assert result is None
64+
65+
66+
class TestExtractSceneTextContent:
67+
def test_extract_empty_list(self):
68+
result = extract_scene_text_content([])
69+
assert result == ""
70+
71+
def test_extract_single_text(self):
72+
content = [{"text": "Hello World"}]
73+
result = extract_scene_text_content(content)
74+
assert result == "\nHello World"
75+
76+
def test_extract_multiple_texts(self):
77+
content = [{"text": "Line 1"}, {"text": "Line 2"}, {"text": "Line 3"}]
78+
result = extract_scene_text_content(content)
79+
assert result == "\nLine 1\nLine 2\nLine 3"
80+
81+
def test_extract_mixed_content(self):
82+
content = [
83+
{"text": "Text 1"},
84+
{"image": "image.png"},
85+
{"text": "Text 2"},
86+
{"button": "Click Me"},
87+
]
88+
result = extract_scene_text_content(content)
89+
assert result == "\nText 1\nText 2"
90+
91+
def test_extract_no_text_content(self):
92+
content = [{"image": "image.png"}, {"button": "Click Me"}]
93+
result = extract_scene_text_content(content)
94+
assert result == ""
95+
96+
97+
class TestChatState:
98+
def test_initial_state(self):
99+
chat_state = ChatState()
100+
assert chat_state.state is None
101+
assert chat_state.waiting_user_input is False
102+
assert chat_state.paused is False
103+
assert chat_state.running_timer_tasks == {}
104+
assert chat_state.input_events == []
105+
assert chat_state.output_events == []
106+
assert chat_state.output_state is None
107+
assert chat_state.events_counter == 0
108+
assert chat_state.first_time is False
109+
110+
111+
class TestRunChat:
112+
@patch("asyncio.run")
113+
@patch.object(chat_module, "LLMRails")
114+
@patch.object(chat_module, "RailsConfig")
115+
def test_run_chat_v1_0(self, mock_rails_config, mock_llm_rails, mock_asyncio_run):
116+
mock_config = MagicMock()
117+
mock_config.colang_version = "1.0"
118+
mock_rails_config.from_path.return_value = mock_config
119+
120+
run_chat(config_path="test_config")
121+
122+
mock_rails_config.from_path.assert_called_once_with("test_config")
123+
mock_asyncio_run.assert_called_once()
124+
125+
@patch.object(chat_module, "get_or_create_event_loop")
126+
@patch.object(chat_module, "LLMRails")
127+
@patch.object(chat_module, "RailsConfig")
128+
def test_run_chat_v2_x(self, mock_rails_config, mock_llm_rails, mock_get_loop):
129+
mock_config = MagicMock()
130+
mock_config.colang_version = "2.x"
131+
mock_rails_config.from_path.return_value = mock_config
132+
133+
mock_loop = MagicMock()
134+
mock_get_loop.return_value = mock_loop
135+
136+
run_chat(config_path="test_config")
137+
138+
mock_rails_config.from_path.assert_called_once_with("test_config")
139+
mock_llm_rails.assert_called_once_with(mock_config, verbose=False)
140+
mock_loop.run_until_complete.assert_called_once()
141+
142+
@patch.object(chat_module, "RailsConfig")
143+
def test_run_chat_invalid_version(self, mock_rails_config):
144+
mock_config = MagicMock()
145+
mock_config.colang_version = "3.0"
146+
mock_rails_config.from_path.return_value = mock_config
147+
148+
with pytest.raises(Exception, match="Invalid colang version"):
149+
run_chat(config_path="test_config")
150+
151+
@patch.object(chat_module, "console")
152+
@patch("asyncio.run")
153+
@patch.object(chat_module, "RailsConfig")
154+
def test_run_chat_verbose_with_llm_calls(
155+
self, mock_rails_config, mock_asyncio_run, mock_console
156+
):
157+
mock_config = MagicMock()
158+
mock_config.colang_version = "1.0"
159+
mock_rails_config.from_path.return_value = mock_config
160+
161+
run_chat(config_path="test_config", verbose=True, verbose_llm_calls=True)
162+
163+
mock_console.print.assert_any_call(
164+
"NOTE: use the `--verbose-no-llm` option to exclude the LLM prompts "
165+
"and completions from the log.\n"
166+
)
167+
168+
169+
class TestRunChatV1Async:
170+
@pytest.mark.asyncio
171+
async def test_run_chat_v1_no_config_no_server(self):
172+
from nemoguardrails.cli.chat import _run_chat_v1_0
173+
174+
with pytest.raises(RuntimeError, match="At least one of"):
175+
await _run_chat_v1_0(config_path=None, server_url=None)
176+
177+
@pytest.mark.asyncio
178+
@patch("builtins.input")
179+
@patch.object(chat_module, "LLMRails")
180+
@patch.object(chat_module, "RailsConfig")
181+
async def test_run_chat_v1_local_config(
182+
self, mock_rails_config, mock_llm_rails, mock_input
183+
):
184+
from nemoguardrails.cli.chat import _run_chat_v1_0
185+
186+
mock_config = MagicMock()
187+
mock_config.streaming_supported = False
188+
mock_rails_config.from_path.return_value = mock_config
189+
190+
mock_rails = AsyncMock()
191+
mock_rails.generate_async = AsyncMock(
192+
return_value={"role": "assistant", "content": "Hello!"}
193+
)
194+
mock_rails.main_llm_supports_streaming = False
195+
mock_llm_rails.return_value = mock_rails
196+
197+
mock_input.side_effect = ["test message", KeyboardInterrupt()]
198+
199+
try:
200+
await _run_chat_v1_0(config_path="test_config")
201+
except KeyboardInterrupt:
202+
pass
203+
204+
mock_rails.generate_async.assert_called_once()
205+
206+
@pytest.mark.asyncio
207+
@patch("builtins.input")
208+
@patch.object(chat_module, "console")
209+
@patch.object(chat_module, "LLMRails")
210+
@patch.object(chat_module, "RailsConfig")
211+
async def test_run_chat_v1_streaming_not_supported(
212+
self, mock_rails_config, mock_llm_rails, mock_console, mock_input
213+
):
214+
from nemoguardrails.cli.chat import _run_chat_v1_0
215+
216+
mock_config = MagicMock()
217+
mock_config.streaming_supported = False
218+
mock_rails_config.from_path.return_value = mock_config
219+
220+
mock_rails = AsyncMock()
221+
mock_llm_rails.return_value = mock_rails
222+
223+
mock_input.side_effect = [KeyboardInterrupt()]
224+
225+
try:
226+
await _run_chat_v1_0(config_path="test_config", streaming=True)
227+
except KeyboardInterrupt:
228+
pass
229+
230+
mock_console.print.assert_any_call(
231+
"WARNING: The config `test_config` does not support streaming. "
232+
"Falling back to normal mode."
233+
)
234+
235+
@pytest.mark.asyncio
236+
@patch("aiohttp.ClientSession")
237+
@patch("builtins.input")
238+
async def test_run_chat_v1_server_mode(self, mock_input, mock_client_session):
239+
from nemoguardrails.cli.chat import _run_chat_v1_0
240+
241+
mock_session = AsyncMock()
242+
mock_response = AsyncMock()
243+
mock_response.headers = {}
244+
mock_response.json = AsyncMock(
245+
return_value={
246+
"messages": [{"role": "assistant", "content": "Server response"}]
247+
}
248+
)
249+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
250+
mock_response.__aexit__ = AsyncMock()
251+
252+
mock_post_context = AsyncMock()
253+
mock_post_context.__aenter__ = AsyncMock(return_value=mock_response)
254+
mock_post_context.__aexit__ = AsyncMock()
255+
mock_session.post = MagicMock(return_value=mock_post_context)
256+
257+
mock_client_session.return_value.__aenter__ = AsyncMock(
258+
return_value=mock_session
259+
)
260+
mock_client_session.return_value.__aexit__ = AsyncMock()
261+
262+
mock_input.side_effect = ["test message", KeyboardInterrupt()]
263+
264+
try:
265+
await _run_chat_v1_0(
266+
server_url="http://localhost:8000", config_id="test_id"
267+
)
268+
except KeyboardInterrupt:
269+
pass
270+
271+
assert mock_session.post.called
272+
call_args = mock_session.post.call_args
273+
assert call_args[0][0] == "http://localhost:8000/v1/chat/completions"
274+
assert "config_id" in call_args[1]["json"]
275+
assert call_args[1]["json"]["config_id"] == "test_id"
276+
assert call_args[1]["json"]["stream"] is False
277+
278+
@pytest.mark.asyncio
279+
@patch("aiohttp.ClientSession")
280+
@patch("builtins.input")
281+
async def test_run_chat_v1_server_streaming(self, mock_input, mock_client_session):
282+
from nemoguardrails.cli.chat import _run_chat_v1_0
283+
284+
mock_session = AsyncMock()
285+
mock_response = AsyncMock()
286+
mock_response.headers = {"Transfer-Encoding": "chunked"}
287+
288+
async def mock_iter_any():
289+
yield b"Stream "
290+
yield b"response"
291+
292+
mock_response.content.iter_any = mock_iter_any
293+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
294+
mock_response.__aexit__ = AsyncMock()
295+
296+
mock_post_context = AsyncMock()
297+
mock_post_context.__aenter__ = AsyncMock(return_value=mock_response)
298+
mock_post_context.__aexit__ = AsyncMock()
299+
mock_session.post = MagicMock(return_value=mock_post_context)
300+
301+
mock_client_session.return_value.__aenter__ = AsyncMock(
302+
return_value=mock_session
303+
)
304+
mock_client_session.return_value.__aexit__ = AsyncMock()
305+
306+
mock_input.side_effect = ["test message", KeyboardInterrupt()]
307+
308+
try:
309+
await _run_chat_v1_0(
310+
server_url="http://localhost:8000", config_id="test_id", streaming=True
311+
)
312+
except KeyboardInterrupt:
313+
pass
314+
315+
assert mock_session.post.called

0 commit comments

Comments
 (0)