Skip to content

Commit 5250349

Browse files
committed
test: add tests for non-persistent models
1 parent 9e79e91 commit 5250349

File tree

1 file changed

+241
-84
lines changed

1 file changed

+241
-84
lines changed

tests/unit/distribution/routers/test_routing_tables.py

Lines changed: 241 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from llama_stack.apis.shields.shields import Shield
1818
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
1919
from llama_stack.apis.vector_dbs import VectorDB
20-
from llama_stack.core.datatypes import RegistryEntrySource
20+
from llama_stack.core.datatypes import RegistryEntrySource, StackRunConfig
2121
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
2222
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
2323
from llama_stack.core.routing_tables.models import ModelsRoutingTable
@@ -534,114 +534,271 @@ async def test_models_source_tracking_provider(cached_disk_dist_registry):
534534
await table.shutdown()
535535

536536

537-
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
538-
"""Test that provider refresh preserves user-registered models with default source."""
539-
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
537+
async def test_models_dynamic_from_config_generation(cached_disk_dist_registry):
538+
"""Test that from_config models are generated dynamically from run_config."""
539+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
540540
await table.initialize()
541541

542-
# First register a user model with same provider_resource_id as provider will later provide
543-
await table.register_model(
544-
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
542+
# Test that no from_config models are registered when no run_config is set
543+
all_models = await table.get_all_with_type("model")
544+
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
545+
assert len(from_config_models) == 0
546+
547+
# Create a run config with from_config models
548+
run_config = StackRunConfig(
549+
image_name="test",
550+
providers={},
551+
models=[
552+
{
553+
"model_id": "from_config_model_1",
554+
"provider_id": "test_provider",
555+
"model_type": "llm",
556+
"provider_model_id": "gpt-3.5-turbo",
557+
},
558+
{
559+
"model_id": "from_config_model_2",
560+
"provider_id": "test_provider",
561+
"model_type": "llm",
562+
"provider_model_id": "gpt-4",
563+
},
564+
],
545565
)
546566

547-
# Verify user model is registered with default source
548-
models = await table.list_models()
549-
assert len(models.data) == 1
550-
user_model = models.data[0]
551-
assert user_model.source == RegistryEntrySource.via_register_api
552-
assert user_model.identifier == "my-custom-alias"
553-
assert user_model.provider_resource_id == "provider-model-1"
567+
# Set the run config
568+
table.current_run_config = run_config
569+
await table.cleanup_disabled_provider_models()
570+
await table.register_from_config_models()
554571

555-
# Now simulate provider refresh
556-
provider_models = [
557-
Model(
558-
identifier="provider-model-1",
559-
provider_resource_id="provider-model-1",
560-
provider_id="test_provider",
561-
metadata={},
562-
model_type=ModelType.llm,
563-
),
564-
Model(
565-
identifier="different-model",
566-
provider_resource_id="different-model",
567-
provider_id="test_provider",
568-
metadata={},
569-
model_type=ModelType.llm,
570-
),
571-
]
572-
await table.update_registered_models("test_provider", provider_models)
572+
# Test that from_config models are registered in the registry
573+
all_models = await table.get_all_with_type("model")
574+
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
575+
assert len(from_config_models) == 2
573576

574-
# Verify user model with alias is preserved, but provider added new model
575-
models = await table.list_models()
576-
assert len(models.data) == 2
577+
model_identifiers = {m.identifier for m in from_config_models}
578+
assert "from_config_model_1" in model_identifiers
579+
assert "from_config_model_2" in model_identifiers
580+
581+
# Test that from_config models have correct attributes
582+
model_1 = next(m for m in from_config_models if m.identifier == "from_config_model_1")
583+
assert model_1.provider_id == "test_provider"
584+
assert model_1.provider_resource_id == "gpt-3.5-turbo"
585+
assert model_1.model_type == ModelType.llm
586+
assert model_1.source == RegistryEntrySource.from_config
587+
588+
# Cleanup
589+
await table.shutdown()
577590

578-
# Find the user model and provider model
579-
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
580-
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
581591

582-
assert user_model is not None
583-
assert user_model.source == RegistryEntrySource.via_register_api
584-
assert user_model.provider_resource_id == "provider-model-1"
592+
async def test_models_dynamic_from_config_lookup(cached_disk_dist_registry):
593+
"""Test that from_config models can be looked up individually."""
594+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
595+
await table.initialize()
596+
597+
# Create a run config with from_config models
598+
run_config = StackRunConfig(
599+
image_name="test",
600+
providers={},
601+
models=[
602+
{
603+
"model_id": "lookup_test_model",
604+
"provider_id": "test_provider",
605+
"model_type": "llm",
606+
"provider_model_id": "gpt-3.5-turbo",
607+
}
608+
],
609+
)
585610

586-
assert provider_model is not None
587-
assert provider_model.source == RegistryEntrySource.listed_from_provider
588-
assert provider_model.provider_resource_id == "different-model"
611+
# Set the run config
612+
table.current_run_config = run_config
613+
await table.cleanup_disabled_provider_models()
614+
await table.register_from_config_models()
615+
616+
# Test that we can get the from_config model individually
617+
model = await table.get_model("lookup_test_model")
618+
assert model is not None
619+
assert model.identifier == "lookup_test_model"
620+
assert model.provider_id == "test_provider"
621+
assert model.provider_resource_id == "gpt-3.5-turbo"
622+
assert model.source == RegistryEntrySource.from_config
589623

590624
# Cleanup
591625
await table.shutdown()
592626

593627

594-
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
595-
"""Test that provider refresh removes old provider models but keeps default ones."""
596-
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
628+
async def test_models_dynamic_from_config_mixed_with_persistent(cached_disk_dist_registry):
629+
"""Test that from_config models work alongside persistent models."""
630+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
597631
await table.initialize()
598632

599-
# Register a user model
600-
await table.register_model(model_id="user-model", provider_id="test_provider")
633+
# Create a run config with from_config models
634+
run_config = StackRunConfig(
635+
image_name="test",
636+
providers={},
637+
models=[
638+
{
639+
"model_id": "from_config_model",
640+
"provider_id": "test_provider",
641+
"model_type": "llm",
642+
"provider_model_id": "gpt-3.5-turbo",
643+
}
644+
],
645+
)
601646

602-
# Add some provider models
603-
provider_models_v1 = [
604-
Model(
605-
identifier="provider-model-old",
606-
provider_resource_id="provider-model-old",
607-
provider_id="test_provider",
608-
metadata={},
609-
model_type=ModelType.llm,
610-
),
611-
]
612-
await table.update_registered_models("test_provider", provider_models_v1)
647+
# Set the run config
648+
table.current_run_config = run_config
649+
await table.cleanup_disabled_provider_models()
650+
await table.register_from_config_models()
613651

614-
# Verify we have both user and provider models
652+
# Test that from_config models are included
615653
models = await table.list_models()
616-
assert len(models.data) == 2
654+
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
617655

618-
# Now update with new provider models (should remove old provider models)
619-
provider_models_v2 = [
620-
Model(
621-
identifier="provider-model-new",
622-
provider_resource_id="provider-model-new",
623-
provider_id="test_provider",
624-
metadata={},
625-
model_type=ModelType.llm,
626-
),
627-
]
628-
await table.update_registered_models("test_provider", provider_models_v2)
656+
assert len(from_config_models) == 1
657+
assert from_config_models[0].identifier == "from_config_model"
658+
659+
# Test that we can get the from_config model individually
660+
from_config_model = await table.get_model("from_config_model")
661+
assert from_config_model is not None
662+
assert from_config_model.source == RegistryEntrySource.from_config
663+
664+
# Cleanup
665+
await table.shutdown()
666+
667+
668+
async def test_models_dynamic_from_config_disabled_providers(cached_disk_dist_registry):
669+
"""Test that from_config models with disabled providers are skipped."""
670+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
671+
await table.initialize()
672+
673+
# Create a run config with disabled provider models
674+
run_config = StackRunConfig(
675+
image_name="test",
676+
providers={},
677+
models=[
678+
{
679+
"model_id": "enabled_model",
680+
"provider_id": "test_provider",
681+
"model_type": "llm",
682+
"provider_model_id": "gpt-3.5-turbo",
683+
},
684+
{
685+
"model_id": "disabled_model",
686+
"provider_id": "__disabled__",
687+
"model_type": "llm",
688+
"provider_model_id": "gpt-4",
689+
},
690+
],
691+
)
692+
693+
# Set the run config
694+
table.current_run_config = run_config
695+
await table.cleanup_disabled_provider_models()
696+
await table.register_from_config_models()
697+
698+
# Test that only enabled models are included
699+
all_models = await table.get_all_with_type("model")
700+
from_config_models = [m for m in all_models if m.source == RegistryEntrySource.from_config]
701+
assert len(from_config_models) == 1
702+
assert from_config_models[0].identifier == "enabled_model"
703+
704+
# Cleanup
705+
await table.shutdown()
706+
707+
708+
async def test_models_dynamic_from_config_no_run_config(cached_disk_dist_registry):
709+
"""Test that list_models returns no from_config models when no run_config is set."""
710+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
711+
await table.initialize()
629712

630-
# Should have user model + new provider model, old provider model gone
713+
# Test that list_models works without run_config
631714
models = await table.list_models()
632-
assert len(models.data) == 2
715+
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
716+
assert len(from_config_models) == 0 # No from_config models when no run_config
717+
718+
# Cleanup
719+
await table.shutdown()
633720

634-
identifiers = {m.identifier for m in models.data}
635-
assert "test_provider/user-model" in identifiers # User model preserved
636-
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
637-
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
638721

639-
# Verify sources are correct
640-
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
641-
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
722+
async def test_models_filter_persistent_models_from_removed_providers(cached_disk_dist_registry):
723+
"""Test that models from removed providers are filtered out from persistent models."""
724+
from llama_stack.apis.models import ModelType
725+
from llama_stack.core.datatypes import ModelWithOwner, Provider, RegistryEntrySource, StackRunConfig
726+
from llama_stack.core.routing_tables.models import ModelsRoutingTable
727+
728+
# Create a routing table
729+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
730+
await table.initialize()
731+
732+
# Create some mock persistent models
733+
model1 = ModelWithOwner(
734+
identifier="test_provider_1/model1",
735+
provider_resource_id="model1",
736+
provider_id="test_provider_1",
737+
metadata={},
738+
model_type=ModelType.llm,
739+
source=RegistryEntrySource.listed_from_provider,
740+
)
741+
model2 = ModelWithOwner(
742+
identifier="test_provider_2/model2",
743+
provider_resource_id="model2",
744+
provider_id="test_provider_2",
745+
metadata={},
746+
model_type=ModelType.llm,
747+
source=RegistryEntrySource.listed_from_provider,
748+
)
749+
user_model = ModelWithOwner(
750+
identifier="user_model",
751+
provider_resource_id="user_model",
752+
provider_id="test_provider_1",
753+
metadata={},
754+
model_type=ModelType.llm,
755+
source=RegistryEntrySource.via_register_api,
756+
)
757+
758+
# Create a run config that only includes test_provider_1 (test_provider_2 is removed)
759+
run_config = StackRunConfig(
760+
image_name="test",
761+
providers={
762+
"inference": [
763+
Provider(provider_id="test_provider_1", provider_type="openai", config={"api_key": "test_key"}),
764+
# test_provider_2 is removed from run.yaml
765+
]
766+
},
767+
models=[],
768+
)
642769

643-
assert user_model.source == RegistryEntrySource.via_register_api
644-
assert provider_model.source == RegistryEntrySource.listed_from_provider
770+
# Set the run config
771+
table.current_run_config = run_config
772+
await table.cleanup_disabled_provider_models()
773+
await table.register_from_config_models()
774+
775+
# Test the cleanup logic directly
776+
# First, manually add models to the registry to simulate existing models
777+
await table.dist_registry.register(model1)
778+
await table.dist_registry.register(model2)
779+
await table.dist_registry.register(user_model)
780+
781+
# Now re-apply the run config and run the cleanup explicitly against the populated registry
782+
table.current_run_config = run_config
783+
await table.cleanup_disabled_provider_models()
784+
await table.register_from_config_models()
785+
786+
# Get the list of models after cleanup
787+
response = await table.list_models()
788+
model_identifiers = {m.identifier for m in response.data}
789+
790+
# Should have user_model (user-registered) and model1 (from enabled provider), but not model2 (from disabled provider)
791+
# model1 should be kept because test_provider_1 is in the run config (enabled)
792+
# model2 should be removed because test_provider_2 is not in the run config (disabled)
793+
# user_model should be kept because it's user-registered
794+
assert "user_model" in model_identifiers
795+
assert "test_provider_1/model1" in model_identifiers
796+
assert "test_provider_2/model2" not in model_identifiers
797+
798+
# Test that user-registered models are always kept regardless of provider status
799+
user_model_found = next((m for m in response.data if m.identifier == "user_model"), None)
800+
assert user_model_found is not None
801+
assert user_model_found.source == RegistryEntrySource.via_register_api
645802

646803
# Cleanup
647804
await table.shutdown()

0 commit comments

Comments
 (0)