|
16 | 16 | import pytest |
17 | 17 | from pydantic import ValidationError |
18 | 18 |
|
19 | | -from nemoguardrails.rails.llm.config import TaskPrompt |
| 19 | +from nemoguardrails.rails.llm.config import ( |
| 20 | + Document, |
| 21 | + Instruction, |
| 22 | + Model, |
| 23 | + RailsConfig, |
| 24 | + TaskPrompt, |
| 25 | +) |
20 | 26 |
|
21 | 27 |
|
22 | 28 | def test_task_prompt_valid_content(): |
@@ -123,3 +129,181 @@ def test_task_prompt_max_tokens_validation(): |
123 | 129 | with pytest.raises(ValidationError) as excinfo: |
124 | 130 | TaskPrompt(task="example_task", content="Test prompt", max_tokens=-1) |
125 | 131 | assert "Input should be greater than or equal to 1" in str(excinfo.value) |
| 132 | + |
| 133 | + |
def test_rails_config_addition():
    """Adding two RailsConfig objects should merge both into a single RailsConfig."""
    left = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path="test_config.yml",
    )
    right = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        config_path="test_config2.yml",
    )

    merged = left + right

    # The merge keeps both model entries and joins the config paths with a comma.
    assert isinstance(merged, RailsConfig)
    assert len(merged.models) == 2
    assert merged.config_path == "test_config.yml,test_config2.yml"
| 151 | + |
def test_rails_config_model_conflicts():
    """Merging configs whose models disagree for the same model type must fail."""
    base = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path="config1.yml",
    )

    # Same model type, but a different engine -> conflict.
    other_engine = RailsConfig(
        models=[Model(type="main", engine="nim", model="gpt-3.5-turbo")],
        config_path="config2.yml",
    )
    with pytest.raises(
        ValueError,
        match="Both config files should have the same engine for the same model type",
    ):
        base + other_engine

    # Same model type and engine family mismatch on the model name -> conflict.
    other_model = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-4")],
        config_path="config3.yml",
    )
    with pytest.raises(
        ValueError,
        match="Both config files should have the same model for the same model type",
    ):
        base + other_model
| 181 | + |
def test_rails_config_actions_server_url_conflicts():
    """Two configs with differing `actions_server_url` values cannot be merged."""
    first = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        actions_server_url="http://localhost:8000",
    )
    second = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        actions_server_url="http://localhost:9000",
    )

    with pytest.raises(
        ValueError, match="Both config files should have the same actions_server_url"
    ):
        first + second
| 199 | + |
def test_rails_config_simple_field_overwriting():
    """Scalar fields from the right-hand config win when two configs are added."""
    older = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        streaming=False,
        lowest_temperature=0.1,
        colang_version="1.0",
    )
    newer = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        streaming=True,
        lowest_temperature=0.5,
        colang_version="2.x",
    )

    combined = older + newer

    # Every simple field should reflect the second (right-hand) config.
    assert combined.streaming is True
    assert combined.lowest_temperature == 0.5
    assert combined.colang_version == "2.x"
| 222 | + |
def test_rails_config_nested_dictionary_merging():
    """Nested dict fields (rails, knowledge_base, custom_data) merge recursively."""
    first = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        rails={
            "input": {"flows": ["flow1"], "parallel": False},
            "output": {"flows": ["flow2"]},
        },
        knowledge_base={
            "folder": "kb1",
            "embedding_search_provider": {"name": "provider1"},
        },
        custom_data={"setting1": "value1", "nested": {"key1": "val1"}},
    )
    second = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        rails={
            "input": {"flows": ["flow3"], "parallel": True},
            "retrieval": {"flows": ["flow4"]},
        },
        knowledge_base={
            "folder": "kb2",
            "embedding_search_provider": {"name": "provider2"},
        },
        custom_data={"setting2": "value2", "nested": {"key2": "val2"}},
    )

    combined = first + second

    # Flow lists are concatenated (second config's flows first); scalars
    # inside the nested dicts take the second config's value.
    assert combined.rails.input.flows == ["flow3", "flow1"]
    assert combined.rails.input.parallel is True
    assert combined.rails.output.flows == ["flow2"]
    assert combined.rails.retrieval.flows == ["flow4"]

    assert combined.knowledge_base.folder == "kb2"
    assert combined.knowledge_base.embedding_search_provider.name == "provider2"

    # custom_data is deep-merged, keeping keys from both sides.
    assert combined.custom_data["setting1"] == "value1"
    assert combined.custom_data["setting2"] == "value2"
    assert combined.custom_data["nested"]["key1"] == "val1"
    assert combined.custom_data["nested"]["key2"] == "val2"
| 266 | + |
def test_rails_config_none_prompts():
    """Merging must tolerate a None prompts field on either side."""
    with_none_prompts = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        prompts=None,
        rails={"input": {"flows": ["self_check_input"]}},
    )
    with_empty_prompts = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        prompts=[],
    )

    merged = with_none_prompts + with_empty_prompts

    assert merged is not None
    assert merged.prompts is not None
| 283 | + |
def test_rails_config_none_config_path():
    """Merging must tolerate a None config_path on either or both sides."""
    # One side None: the result should be just the other path (no leading comma).
    no_path = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path=None,
    )
    with_path = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        config_path="config2.yml",
    )

    merged = no_path + with_path
    # should not have leading comma after fix
    assert merged.config_path == "config2.yml"

    # Both sides None: the merged path collapses to the empty string.
    left_none = RailsConfig(
        models=[Model(type="main", engine="openai", model="gpt-3.5-turbo")],
        config_path=None,
    )
    right_none = RailsConfig(
        models=[Model(type="secondary", engine="anthropic", model="claude-3")],
        config_path=None,
    )

    both_none = left_none + right_none
    assert both_none.config_path == ""
0 commit comments