Skip to content

Commit 12d9a08

Browse files
committed
test: add tests for non-persistent models
1 parent 0f208c9 commit 12d9a08

File tree

1 file changed

+146
-85
lines changed

1 file changed

+146
-85
lines changed

tests/unit/distribution/routers/test_routing_tables.py

Lines changed: 146 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from llama_stack.apis.shields.shields import Shield
1818
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
1919
from llama_stack.apis.vector_dbs import VectorDB
20-
from llama_stack.core.datatypes import RegistryEntrySource
20+
from llama_stack.core.datatypes import RegistryEntrySource, StackRunConfig
2121
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
2222
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
2323
from llama_stack.core.routing_tables.models import ModelsRoutingTable
@@ -534,114 +534,175 @@ async def test_models_source_tracking_provider(cached_disk_dist_registry):
534534
await table.shutdown()
535535

536536

537-
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
538-
"""Test that provider refresh preserves user-registered models with default source."""
539-
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
537+
async def test_models_dynamic_from_config_generation(cached_disk_dist_registry):
538+
"""Test that from_config models are generated dynamically from run_config."""
539+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
540540
await table.initialize()
541541

542-
# First register a user model with same provider_resource_id as provider will later provide
543-
await table.register_model(
544-
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
542+
# Test that _generate_from_config_models returns empty list when no run_config
543+
from_config_models = table._generate_from_config_models()
544+
assert len(from_config_models) == 0
545+
546+
# Create a run config with from_config models
547+
run_config = StackRunConfig(
548+
image_name="test",
549+
providers={},
550+
models=[
551+
{
552+
"model_id": "from_config_model_1",
553+
"provider_id": "test_provider",
554+
"model_type": "llm",
555+
"provider_model_id": "gpt-3.5-turbo",
556+
},
557+
{
558+
"model_id": "from_config_model_2",
559+
"provider_id": "test_provider",
560+
"model_type": "llm",
561+
"provider_model_id": "gpt-4",
562+
},
563+
],
545564
)
546565

547-
# Verify user model is registered with default source
548-
models = await table.list_models()
549-
assert len(models.data) == 1
550-
user_model = models.data[0]
551-
assert user_model.source == RegistryEntrySource.via_register_api
552-
assert user_model.identifier == "my-custom-alias"
553-
assert user_model.provider_resource_id == "provider-model-1"
566+
# Set the run config
567+
table.set_run_config(run_config)
554568

555-
# Now simulate provider refresh
556-
provider_models = [
557-
Model(
558-
identifier="provider-model-1",
559-
provider_resource_id="provider-model-1",
560-
provider_id="test_provider",
561-
metadata={},
562-
model_type=ModelType.llm,
563-
),
564-
Model(
565-
identifier="different-model",
566-
provider_resource_id="different-model",
567-
provider_id="test_provider",
568-
metadata={},
569-
model_type=ModelType.llm,
570-
),
571-
]
572-
await table.update_registered_models("test_provider", provider_models)
569+
# Test _generate_from_config_models returns the models
570+
from_config_models = table._generate_from_config_models()
571+
assert len(from_config_models) == 2
573572

574-
# Verify user model with alias is preserved, but provider added new model
575-
models = await table.list_models()
576-
assert len(models.data) == 2
573+
model_identifiers = {m.identifier for m in from_config_models}
574+
assert "from_config_model_1" in model_identifiers
575+
assert "from_config_model_2" in model_identifiers
576+
577+
# Test that from_config models have correct attributes
578+
model_1 = next(m for m in from_config_models if m.identifier == "from_config_model_1")
579+
assert model_1.provider_id == "test_provider"
580+
assert model_1.provider_resource_id == "gpt-3.5-turbo"
581+
assert model_1.model_type == ModelType.llm
582+
assert model_1.source == RegistryEntrySource.from_config
583+
584+
# Cleanup
585+
await table.shutdown()
586+
587+
588+
async def test_models_dynamic_from_config_lookup(cached_disk_dist_registry):
589+
"""Test that from_config models can be looked up individually."""
590+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
591+
await table.initialize()
577592

578-
# Find the user model and provider model
579-
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
580-
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
593+
# Create a run config with from_config models
594+
run_config = StackRunConfig(
595+
image_name="test",
596+
providers={},
597+
models=[
598+
{
599+
"model_id": "lookup_test_model",
600+
"provider_id": "test_provider",
601+
"model_type": "llm",
602+
"provider_model_id": "gpt-3.5-turbo",
603+
}
604+
],
605+
)
581606

582-
assert user_model is not None
583-
assert user_model.source == RegistryEntrySource.via_register_api
584-
assert user_model.provider_resource_id == "provider-model-1"
607+
# Set the run config
608+
table.set_run_config(run_config)
585609

586-
assert provider_model is not None
587-
assert provider_model.source == RegistryEntrySource.listed_from_provider
588-
assert provider_model.provider_resource_id == "different-model"
610+
# Test that we can get the from_config model individually
611+
model = await table.get_model("lookup_test_model")
612+
assert model is not None
613+
assert model.identifier == "lookup_test_model"
614+
assert model.provider_id == "test_provider"
615+
assert model.provider_resource_id == "gpt-3.5-turbo"
616+
assert model.source == RegistryEntrySource.from_config
589617

590618
# Cleanup
591619
await table.shutdown()
592620

593621

594-
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
595-
"""Test that provider refresh removes old provider models but keeps default ones."""
596-
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
622+
async def test_models_dynamic_from_config_mixed_with_persistent(cached_disk_dist_registry):
623+
"""Test that from_config models work alongside persistent models."""
624+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
597625
await table.initialize()
598626

599-
# Register a user model
600-
await table.register_model(model_id="user-model", provider_id="test_provider")
627+
# Create a run config with from_config models
628+
run_config = StackRunConfig(
629+
image_name="test",
630+
providers={},
631+
models=[
632+
{
633+
"model_id": "from_config_model",
634+
"provider_id": "test_provider",
635+
"model_type": "llm",
636+
"provider_model_id": "gpt-3.5-turbo",
637+
}
638+
],
639+
)
601640

602-
# Add some provider models
603-
provider_models_v1 = [
604-
Model(
605-
identifier="provider-model-old",
606-
provider_resource_id="provider-model-old",
607-
provider_id="test_provider",
608-
metadata={},
609-
model_type=ModelType.llm,
610-
),
611-
]
612-
await table.update_registered_models("test_provider", provider_models_v1)
641+
# Set the run config
642+
table.set_run_config(run_config)
613643

614-
# Verify we have both user and provider models
644+
# Test that from_config models are included
615645
models = await table.list_models()
616-
assert len(models.data) == 2
646+
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
617647

618-
# Now update with new provider models (should remove old provider models)
619-
provider_models_v2 = [
620-
Model(
621-
identifier="provider-model-new",
622-
provider_resource_id="provider-model-new",
623-
provider_id="test_provider",
624-
metadata={},
625-
model_type=ModelType.llm,
626-
),
627-
]
628-
await table.update_registered_models("test_provider", provider_models_v2)
648+
assert len(from_config_models) == 1
649+
assert from_config_models[0].identifier == "from_config_model"
629650

630-
# Should have user model + new provider model, old provider model gone
631-
models = await table.list_models()
632-
assert len(models.data) == 2
651+
# Test that we can get the from_config model individually
652+
from_config_model = await table.get_model("from_config_model")
653+
assert from_config_model is not None
654+
assert from_config_model.source == RegistryEntrySource.from_config
633655

634-
identifiers = {m.identifier for m in models.data}
635-
assert "test_provider/user-model" in identifiers # User model preserved
636-
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
637-
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
656+
# Cleanup
657+
await table.shutdown()
658+
659+
660+
async def test_models_dynamic_from_config_disabled_providers(cached_disk_dist_registry):
661+
"""Test that from_config models with disabled providers are skipped."""
662+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
663+
await table.initialize()
664+
665+
# Create a run config with disabled provider models
666+
run_config = StackRunConfig(
667+
image_name="test",
668+
providers={},
669+
models=[
670+
{
671+
"model_id": "enabled_model",
672+
"provider_id": "test_provider",
673+
"model_type": "llm",
674+
"provider_model_id": "gpt-3.5-turbo",
675+
},
676+
{
677+
"model_id": "disabled_model",
678+
"provider_id": "__disabled__",
679+
"model_type": "llm",
680+
"provider_model_id": "gpt-4",
681+
},
682+
],
683+
)
638684

639-
# Verify sources are correct
640-
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
641-
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
685+
# Set the run config
686+
table.set_run_config(run_config)
642687

643-
assert user_model.source == RegistryEntrySource.via_register_api
644-
assert provider_model.source == RegistryEntrySource.listed_from_provider
688+
# Test that only enabled models are included
689+
from_config_models = table._generate_from_config_models()
690+
assert len(from_config_models) == 1
691+
assert from_config_models[0].identifier == "enabled_model"
692+
693+
# Cleanup
694+
await table.shutdown()
695+
696+
697+
async def test_models_dynamic_from_config_no_run_config(cached_disk_dist_registry):
698+
"""Test that from_config models work when no run_config is set."""
699+
table = ModelsRoutingTable({}, cached_disk_dist_registry, {})
700+
await table.initialize()
701+
702+
# Test that list_models works without run_config
703+
models = await table.list_models()
704+
from_config_models = [m for m in models.data if m.source == RegistryEntrySource.from_config]
705+
assert len(from_config_models) == 0 # No from_config models when no run_config
645706

646707
# Cleanup
647708
await table.shutdown()

0 commit comments

Comments
 (0)