Skip to content

Commit 6430085

Browse files
authored
fix(config): ensure adding RailsConfig objects handles None values (#1328)
1 parent 6da8010 commit 6430085

File tree

2 files changed

+194
-9
lines changed

nemoguardrails/rails/llm/config.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,12 +1350,13 @@ def check_reasoning_traces_with_dialog_rails(cls, values):
13501350
@root_validator(pre=True, allow_reuse=True)
13511351
def check_prompt_exist_for_self_check_rails(cls, values):
13521352
rails = values.get("rails", {})
1353+
prompts = values.get("prompts", []) or []
13531354

13541355
enabled_input_rails = rails.get("input", {}).get("flows", [])
13551356
enabled_output_rails = rails.get("output", {}).get("flows", [])
13561357
provided_task_prompts = [
13571358
prompt.task if hasattr(prompt, "task") else prompt.get("task")
1358-
for prompt in values.get("prompts", [])
1359+
for prompt in prompts
13591360
]
13601361

13611362
# Input moderation prompt verification
@@ -1410,7 +1411,7 @@ def check_output_parser_exists(cls, values):
14101411
# "content_safety_check input $model",
14111412
# "content_safety_check output $model",
14121413
]
1413-
prompts = values.get("prompts", [])
1414+
prompts = values.get("prompts") or []
14141415
for prompt in prompts:
14151416
task = prompt.task if hasattr(prompt, "task") else prompt.get("task")
14161417
output_parser = (
@@ -1657,12 +1658,12 @@ def _join_rails_configs(
16571658
combined_rails_config_dict = _join_dict(
16581659
base_rails_config.dict(), updated_rails_config.dict()
16591660
)
1660-
combined_rails_config_dict["config_path"] = ",".join(
1661-
[
1662-
base_rails_config.dict()["config_path"],
1663-
updated_rails_config.dict()["config_path"],
1664-
]
1665-
)
1661+
# filter out empty strings to avoid leading/trailing commas
1662+
config_paths = [
1663+
base_rails_config.dict()["config_path"] or "",
1664+
updated_rails_config.dict()["config_path"] or "",
1665+
]
1666+
combined_rails_config_dict["config_path"] = ",".join(filter(None, config_paths))
16661667
combined_rails_config = RailsConfig(**combined_rails_config_dict)
16671668
return combined_rails_config
16681669

tests/rails/llm/test_config.py

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
import pytest
1717
from pydantic import ValidationError
1818

19-
from nemoguardrails.rails.llm.config import TaskPrompt
19+
from nemoguardrails.rails.llm.config import (
20+
Document,
21+
Instruction,
22+
Model,
23+
RailsConfig,
24+
TaskPrompt,
25+
)
2026

2127

2228
def test_task_prompt_valid_content():
@@ -123,3 +129,181 @@ def test_task_prompt_max_tokens_validation():
123129
with pytest.raises(ValidationError) as excinfo:
124130
TaskPrompt(task="example_task", content="Test prompt", max_tokens=-1)
125131
assert "Input should be greater than or equal to 1" in str(excinfo.value)
132+
133+
134+
def test_rails_config_addition():
135+
"""Tests that adding two RailsConfig objects merges both into a single RailsConfig."""
136+
config1 = RailsConfig(
137+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
138+
config_path="test_config.yml",
139+
)
140+
config2 = RailsConfig(
141+
models=[Model(type="secondary", engine="anthropic", model="claude-3")],
142+
config_path="test_config2.yml",
143+
)
144+
145+
result = config1 + config2
146+
147+
assert isinstance(result, RailsConfig)
148+
assert len(result.models) == 2
149+
assert result.config_path == "test_config.yml,test_config2.yml"
150+
151+
152+
def test_rails_config_model_conflicts():
153+
"""Tests that adding two RailsConfig objects with conflicting models raises an error."""
154+
config1 = RailsConfig(
155+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
156+
config_path="config1.yml",
157+
)
158+
159+
# Different engine for same model type
160+
config2 = RailsConfig(
161+
models=[Model(type="main", engine="nim", model="gpt-3.5-turbo")],
162+
config_path="config2.yml",
163+
)
164+
with pytest.raises(
165+
ValueError,
166+
match="Both config files should have the same engine for the same model type",
167+
):
168+
config1 + config2
169+
170+
# Different model for same model type
171+
config3 = RailsConfig(
172+
models=[Model(type="main", engine="openai", model="gpt-4")],
173+
config_path="config3.yml",
174+
)
175+
with pytest.raises(
176+
ValueError,
177+
match="Both config files should have the same model for the same model type",
178+
):
179+
config1 + config3
180+
181+
182+
def test_rails_config_actions_server_url_conflicts():
183+
"""Tests that adding two RailsConfig objects with different values for `actions_server_url` raises an error."""
184+
config1 = RailsConfig(
185+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
186+
actions_server_url="http://localhost:8000",
187+
)
188+
189+
config2 = RailsConfig(
190+
models=[Model(type="secondary", engine="anthropic", model="claude-3")],
191+
actions_server_url="http://localhost:9000",
192+
)
193+
194+
with pytest.raises(
195+
ValueError, match="Both config files should have the same actions_server_url"
196+
):
197+
config1 + config2
198+
199+
200+
def test_rails_config_simple_field_overwriting():
201+
"""Tests that fields from the second config overwrite fields from the first config."""
202+
config1 = RailsConfig(
203+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
204+
streaming=False,
205+
lowest_temperature=0.1,
206+
colang_version="1.0",
207+
)
208+
209+
config2 = RailsConfig(
210+
models=[Model(type="secondary", engine="anthropic", model="claude-3")],
211+
streaming=True,
212+
lowest_temperature=0.5,
213+
colang_version="2.x",
214+
)
215+
216+
result = config1 + config2
217+
218+
assert result.streaming is True
219+
assert result.lowest_temperature == 0.5
220+
assert result.colang_version == "2.x"
221+
222+
223+
def test_rails_config_nested_dictionary_merging():
224+
"""Tests nested dictionaries are merged correctly."""
225+
config1 = RailsConfig(
226+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
227+
rails={
228+
"input": {"flows": ["flow1"], "parallel": False},
229+
"output": {"flows": ["flow2"]},
230+
},
231+
knowledge_base={
232+
"folder": "kb1",
233+
"embedding_search_provider": {"name": "provider1"},
234+
},
235+
custom_data={"setting1": "value1", "nested": {"key1": "val1"}},
236+
)
237+
238+
config2 = RailsConfig(
239+
models=[Model(type="secondary", engine="anthropic", model="claude-3")],
240+
rails={
241+
"input": {"flows": ["flow3"], "parallel": True},
242+
"retrieval": {"flows": ["flow4"]},
243+
},
244+
knowledge_base={
245+
"folder": "kb2",
246+
"embedding_search_provider": {"name": "provider2"},
247+
},
248+
custom_data={"setting2": "value2", "nested": {"key2": "val2"}},
249+
)
250+
251+
result = config1 + config2
252+
253+
assert result.rails.input.flows == ["flow3", "flow1"]
254+
assert result.rails.input.parallel is True
255+
assert result.rails.output.flows == ["flow2"]
256+
assert result.rails.retrieval.flows == ["flow4"]
257+
258+
assert result.knowledge_base.folder == "kb2"
259+
assert result.knowledge_base.embedding_search_provider.name == "provider2"
260+
261+
assert result.custom_data["setting1"] == "value1"
262+
assert result.custom_data["setting2"] == "value2"
263+
assert result.custom_data["nested"]["key1"] == "val1"
264+
assert result.custom_data["nested"]["key2"] == "val2"
265+
266+
267+
def test_rails_config_none_prompts():
268+
"""Test that configs with None prompts can be added without errors."""
269+
config1 = RailsConfig(
270+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
271+
prompts=None,
272+
rails={"input": {"flows": ["self_check_input"]}},
273+
)
274+
config2 = RailsConfig(
275+
models=[Model(type="secondary", engine="anthropic", model="claude-3")],
276+
prompts=[],
277+
)
278+
279+
result = config1 + config2
280+
assert result is not None
281+
assert result.prompts is not None
282+
283+
284+
def test_rails_config_none_config_path():
285+
"""Test that configs with None config_path can be added."""
286+
config1 = RailsConfig(
287+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
288+
config_path=None,
289+
)
290+
config2 = RailsConfig(
291+
models=[Model(type="secondary", engine="anthropic", model="claude-3")],
292+
config_path="config2.yml",
293+
)
294+
295+
result = config1 + config2
296+
# should not have leading comma after fix
297+
assert result.config_path == "config2.yml"
298+
299+
config3 = RailsConfig(
300+
models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
301+
config_path=None,
302+
)
303+
config4 = RailsConfig(
304+
models=[Model(type="secondary", engine="anthropic", model="claude-3")],
305+
config_path=None,
306+
)
307+
308+
result2 = config3 + config4
309+
assert result2.config_path == ""

0 commit comments

Comments (0)