@@ -1155,3 +1155,100 @@ async def test_stream_usage_enabled_for_all_providers_when_streaming(
11551155
11561156 # stream_usage should be set for all providers when streaming is enabled
11571157 assert kwargs .get ("stream_usage" ) is True
1158+
1159+
1160+ # Add this test after the existing tests, around line 1100+
1161+
1162+
1163+ def test_register_methods_return_self ():
1164+ """Test that all register_* methods return self for method chaining."""
1165+ config = RailsConfig .from_content (config = {"models" : []})
1166+ rails = LLMRails (config = config , llm = FakeLLM (responses = []))
1167+
1168+ # Test register_action returns self
1169+ def dummy_action ():
1170+ pass
1171+
1172+ result = rails .register_action (dummy_action , "test_action" )
1173+ assert result is rails , "register_action should return self"
1174+
1175+ # Test register_action_param returns self
1176+ result = rails .register_action_param ("test_param" , "test_value" )
1177+ assert result is rails , "register_action_param should return self"
1178+
1179+ # Test register_filter returns self
1180+ def dummy_filter (text ):
1181+ return text
1182+
1183+ result = rails .register_filter (dummy_filter , "test_filter" )
1184+ assert result is rails , "register_filter should return self"
1185+
1186+ # Test register_output_parser returns self
1187+ def dummy_parser (text ):
1188+ return text
1189+
1190+ result = rails .register_output_parser (dummy_parser , "test_parser" )
1191+ assert result is rails , "register_output_parser should return self"
1192+
1193+ # Test register_prompt_context returns self
1194+ result = rails .register_prompt_context ("test_context" , "test_value" )
1195+ assert result is rails , "register_prompt_context should return self"
1196+
1197+ # Test register_embedding_search_provider returns self
1198+ from nemoguardrails .embeddings .index import EmbeddingsIndex
1199+
1200+ class DummyEmbeddingProvider (EmbeddingsIndex ):
1201+ def __init__ (self , ** kwargs ):
1202+ pass
1203+
1204+ def build (self ):
1205+ pass
1206+
1207+ def search (self , text , max_results = 5 ):
1208+ return []
1209+
1210+ result = rails .register_embedding_search_provider (
1211+ "dummy_provider" , DummyEmbeddingProvider
1212+ )
1213+ assert result is rails , "register_embedding_search_provider should return self"
1214+
1215+ # Test register_embedding_provider returns self
1216+ from nemoguardrails .embeddings .providers .base import EmbeddingModel
1217+
1218+ class DummyEmbeddingModel (EmbeddingModel ):
1219+ def encode (self , texts ):
1220+ return []
1221+
1222+ result = rails .register_embedding_provider (DummyEmbeddingModel , "dummy_embedding" )
1223+ assert result is rails , "register_embedding_provider should return self"
1224+
1225+
1226+ def test_method_chaining ():
1227+ """Test that method chaining works correctly with register_* methods."""
1228+ config = RailsConfig .from_content (config = {"models" : []})
1229+ rails = LLMRails (config = config , llm = FakeLLM (responses = []))
1230+
1231+ def dummy_action ():
1232+ return "action_result"
1233+
1234+ def dummy_filter (text ):
1235+ return text .upper ()
1236+
1237+ def dummy_parser (text ):
1238+ return {"parsed" : text }
1239+
1240+ # Test chaining multiple register methods
1241+ result = (
1242+ rails .register_action (dummy_action , "chained_action" )
1243+ .register_action_param ("chained_param" , "param_value" )
1244+ .register_filter (dummy_filter , "chained_filter" )
1245+ .register_output_parser (dummy_parser , "chained_parser" )
1246+ .register_prompt_context ("chained_context" , "context_value" )
1247+ )
1248+
1249+ assert result is rails , "Method chaining should return the same rails instance"
1250+
1251+ # Verify that all registrations actually worked
1252+ assert "chained_action" in rails .runtime .action_dispatcher .registered_actions
1253+ assert "chained_param" in rails .runtime .registered_action_params
1254+ assert rails .runtime .registered_action_params ["chained_param" ] == "param_value"
0 commit comments