Skip to content

Commit f32e42a

Browse files
committed
test: add tests for model cache cleanup functionality
1 parent 62d1d59 commit f32e42a

File tree

1 file changed

+256
-6
lines changed

1 file changed

+256
-6
lines changed

tests/unit/distribution/routers/test_routing_tables.py

Lines changed: 256 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,36 @@ async def test_models_source_tracking_provider(cached_disk_dist_registry):
494494
await table.shutdown()
495495

496496

497+
async def test_models_source_tracking_config(cached_disk_dist_registry):
498+
"""Test that models registered via direct object registration get config source."""
499+
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
500+
await table.initialize()
501+
502+
# Register model directly with config source (simulating config-based registration)
503+
from llama_stack.core.datatypes import ModelWithOwner
504+
505+
config_model = ModelWithOwner(
506+
identifier="config-model",
507+
provider_resource_id="config-model",
508+
provider_id="test_provider",
509+
metadata={},
510+
model_type=ModelType.llm,
511+
source=RegistryEntrySource.from_config,
512+
)
513+
514+
# Use the internal register_object method to simulate config registration
515+
await table.register_object(config_model)
516+
517+
models = await table.list_models()
518+
assert len(models.data) == 1
519+
model = models.data[0]
520+
assert model.source == RegistryEntrySource.from_config
521+
assert model.identifier == "config-model"
522+
523+
# Cleanup
524+
await table.shutdown()
525+
526+
497527
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
498528
"""Test that provider refresh preserves user-registered models with default source."""
499529
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
@@ -551,14 +581,94 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
551581
await table.shutdown()
552582

553583

584+
async def test_models_source_interaction_preserves_config(cached_disk_dist_registry):
585+
"""Test that provider refresh preserves config-registered models."""
586+
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
587+
await table.initialize()
588+
589+
# First register a config model with same provider_resource_id as provider will later provide
590+
from llama_stack.core.datatypes import ModelWithOwner
591+
592+
config_model = ModelWithOwner(
593+
identifier="config-model",
594+
provider_resource_id="provider-model-1",
595+
provider_id="test_provider",
596+
metadata={},
597+
model_type=ModelType.llm,
598+
source=RegistryEntrySource.from_config,
599+
)
600+
601+
await table.register_object(config_model)
602+
603+
# Verify config model is registered with config source
604+
models = await table.list_models()
605+
assert len(models.data) == 1
606+
config_model_obj = models.data[0]
607+
assert config_model_obj.source == RegistryEntrySource.from_config
608+
assert config_model_obj.identifier == "config-model"
609+
assert config_model_obj.provider_resource_id == "provider-model-1"
610+
611+
# Now simulate provider refresh
612+
provider_models = [
613+
Model(
614+
identifier="provider-model-1",
615+
provider_resource_id="provider-model-1",
616+
provider_id="test_provider",
617+
metadata={},
618+
model_type=ModelType.llm,
619+
),
620+
Model(
621+
identifier="different-model",
622+
provider_resource_id="different-model",
623+
provider_id="test_provider",
624+
metadata={},
625+
model_type=ModelType.llm,
626+
),
627+
]
628+
await table.update_registered_models("test_provider", provider_models)
629+
630+
# Verify config model is preserved, but provider added new model
631+
models = await table.list_models()
632+
assert len(models.data) == 2
633+
634+
# Find the config model and provider model
635+
config_model_obj = next((m for m in models.data if m.identifier == "config-model"), None)
636+
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
637+
638+
assert config_model_obj is not None
639+
assert config_model_obj.source == RegistryEntrySource.from_config
640+
assert config_model_obj.provider_resource_id == "provider-model-1"
641+
642+
assert provider_model is not None
643+
assert provider_model.source == RegistryEntrySource.listed_from_provider
644+
assert provider_model.provider_resource_id == "different-model"
645+
646+
# Cleanup
647+
await table.shutdown()
648+
649+
554650
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
555-
"""Test that provider refresh removes old provider models but keeps default ones."""
651+
"""Test that provider refresh removes old provider models but keeps default and config ones."""
556652
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
557653
await table.initialize()
558654

559655
# Register a user model
560656
await table.register_model(model_id="user-model", provider_id="test_provider")
561657

658+
# Register a config model
659+
from llama_stack.core.datatypes import ModelWithOwner
660+
661+
config_model = ModelWithOwner(
662+
identifier="config-model",
663+
provider_resource_id="config-model",
664+
provider_id="test_provider",
665+
metadata={},
666+
model_type=ModelType.llm,
667+
source=RegistryEntrySource.from_config,
668+
)
669+
670+
await table.register_object(config_model)
671+
562672
# Add some provider models
563673
provider_models_v1 = [
564674
Model(
@@ -571,9 +681,9 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
571681
]
572682
await table.update_registered_models("test_provider", provider_models_v1)
573683

574-
# Verify we have both user and provider models
684+
# Verify we have user, config, and provider models
575685
models = await table.list_models()
576-
assert len(models.data) == 2
686+
assert len(models.data) == 3
577687

578688
# Now update with new provider models (should remove old provider models)
579689
provider_models_v2 = [
@@ -587,20 +697,160 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
587697
]
588698
await table.update_registered_models("test_provider", provider_models_v2)
589699

590-
# Should have user model + new provider model, old provider model gone
700+
# Should have user model + config model + new provider model, old provider model gone
591701
models = await table.list_models()
592-
assert len(models.data) == 2
702+
assert len(models.data) == 3
593703

594704
identifiers = {m.identifier for m in models.data}
595705
assert "test_provider/user-model" in identifiers # User model preserved
596-
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
706+
assert "config-model" in identifiers # Config model preserved
707+
assert "test_provider/provider-model-new" in identifiers # New provider model
597708
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
598709

599710
# Verify sources are correct
600711
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
712+
config_model_obj = next((m for m in models.data if m.identifier == "config-model"), None)
601713
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
602714

603715
assert user_model.source == RegistryEntrySource.via_register_api
716+
assert config_model_obj.source == RegistryEntrySource.from_config
717+
assert provider_model.source == RegistryEntrySource.listed_from_provider
718+
719+
# Cleanup
720+
await table.shutdown()
721+
722+
723+
async def test_models_cleanup_ephemeral_models(cached_disk_dist_registry):
724+
"""Test that cleanup_ephemeral_models only removes provider models."""
725+
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
726+
await table.initialize()
727+
728+
# Register models with different sources
729+
await table.register_model(model_id="user-model", provider_id="test_provider")
730+
731+
from llama_stack.core.datatypes import ModelWithOwner
732+
733+
config_model = ModelWithOwner(
734+
identifier="config-model",
735+
provider_resource_id="config-model",
736+
provider_id="test_provider",
737+
metadata={},
738+
model_type=ModelType.llm,
739+
source=RegistryEntrySource.from_config,
740+
)
741+
742+
await table.register_object(config_model)
743+
744+
# Add provider models
745+
provider_models = [
746+
Model(
747+
identifier="provider-model",
748+
provider_resource_id="provider-model",
749+
provider_id="test_provider",
750+
metadata={},
751+
model_type=ModelType.llm,
752+
),
753+
]
754+
await table.update_registered_models("test_provider", provider_models)
755+
756+
# Verify we have all three types
757+
models = await table.list_models()
758+
assert len(models.data) == 3
759+
760+
# Call cleanup_ephemeral_models
761+
await table.cleanup_ephemeral_models()
762+
763+
# Should only have user and config models, provider model removed
764+
models = await table.list_models()
765+
assert len(models.data) == 2
766+
767+
identifiers = {m.identifier for m in models.data}
768+
assert "test_provider/user-model" in identifiers # User model preserved
769+
assert "config-model" in identifiers # Config model preserved
770+
assert "test_provider/provider-model" not in identifiers # Provider model removed
771+
772+
# Cleanup
773+
await table.shutdown()
774+
775+
776+
async def test_models_cleanup_config_models(cached_disk_dist_registry):
777+
"""Test that cleanup_config_models only removes config models."""
778+
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
779+
await table.initialize()
780+
781+
# Register models with different sources
782+
await table.register_model(model_id="user-model", provider_id="test_provider")
783+
784+
from llama_stack.core.datatypes import ModelWithOwner
785+
786+
config_model = ModelWithOwner(
787+
identifier="config-model",
788+
provider_resource_id="config-model",
789+
provider_id="test_provider",
790+
metadata={},
791+
model_type=ModelType.llm,
792+
source=RegistryEntrySource.from_config,
793+
)
794+
795+
await table.register_object(config_model)
796+
797+
# Add provider models
798+
provider_models = [
799+
Model(
800+
identifier="provider-model",
801+
provider_resource_id="provider-model",
802+
provider_id="test_provider",
803+
metadata={},
804+
model_type=ModelType.llm,
805+
),
806+
]
807+
await table.update_registered_models("test_provider", provider_models)
808+
809+
# Verify we have all three types
810+
models = await table.list_models()
811+
assert len(models.data) == 3
812+
813+
# Call cleanup_config_models
814+
await table.cleanup_config_models()
815+
816+
# Should only have user and provider models, config model removed
817+
models = await table.list_models()
818+
assert len(models.data) == 2
819+
820+
identifiers = {m.identifier for m in models.data}
821+
assert "test_provider/user-model" in identifiers # User model preserved
822+
assert "test_provider/provider-model" in identifiers # Provider model preserved
823+
assert "config-model" not in identifiers # Config model removed
824+
825+
# Cleanup
826+
await table.shutdown()
827+
828+
829+
async def test_models_register_model_with_source_parameter(cached_disk_dist_registry):
830+
"""Test that register_model accepts and uses the source parameter."""
831+
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
832+
await table.initialize()
833+
834+
# Test registering with explicit config source
835+
await table.register_model(
836+
model_id="config-model", provider_id="test_provider", source=RegistryEntrySource.from_config
837+
)
838+
839+
models = await table.list_models()
840+
assert len(models.data) == 1
841+
model = models.data[0]
842+
assert model.source == RegistryEntrySource.from_config
843+
assert model.identifier == "test_provider/config-model"
844+
845+
# Test registering with explicit provider source
846+
await table.register_model(
847+
model_id="provider-model", provider_id="test_provider", source=RegistryEntrySource.listed_from_provider
848+
)
849+
850+
models = await table.list_models()
851+
assert len(models.data) == 2
852+
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model"), None)
853+
assert provider_model is not None
604854
assert provider_model.source == RegistryEntrySource.listed_from_provider
605855

606856
# Cleanup

0 commit comments

Comments
 (0)