Skip to content

Commit d24c989

Browse files
authored
don't delete models on provider update (#836)
This makes sure that the foreign key references stay intact in the muxing table. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent 16d525f commit d24c989

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

src/codegate/db/connection.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -469,14 +469,15 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
469469
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
470470
return added_model
471471

472-
async def delete_provider_models(self, provider_id: str):
472+
async def delete_provider_model(self, provider_id: str, model: str) -> Optional[ProviderModel]:
473473
sql = text(
474474
"""
475475
DELETE FROM provider_models
476-
WHERE provider_endpoint_id = :provider_endpoint_id
476+
WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name
477477
"""
478478
)
479-
conditions = {"provider_endpoint_id": provider_id}
479+
480+
conditions = {"provider_endpoint_id": provider_id, "name": model}
480481
await self._execute_with_no_return(sql, conditions)
481482

482483
async def delete_muxes_by_workspace(self, workspace_id: str):

src/codegate/providers/crud/crud.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -141,26 +141,40 @@ async def update_endpoint(
141141
except Exception as err:
142142
raise ValueError("Unable to get models from provider: {}".format(str(err)))
143143

144-
# Reset all provider models.
145-
await self._db_writer.delete_provider_models(str(endpoint.id))
144+
models_set = set(models)
146145

147-
for model in models:
146+
# Get the models from the provider
147+
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id))
148+
149+
models_in_db_set = set(model.name for model in models_in_db)
150+
151+
# Add the models that are in the provider but not in the DB
152+
for model in models_set - models_in_db_set:
148153
await self._db_writer.add_provider_model(
149154
dbmodels.ProviderModel(
150155
provider_endpoint_id=founddbe.id,
151156
name=model,
152157
)
153158
)
154159

160+
# Remove the models that are in the DB but not in the provider
161+
for model in models_in_db_set - models_set:
162+
await self._db_writer.delete_provider_model(
163+
founddbe.id,
164+
model,
165+
)
166+
155167
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())
156168

157-
await self._db_writer.push_provider_auth_material(
158-
dbmodels.ProviderAuthMaterial(
159-
provider_endpoint_id=dbendpoint.id,
160-
auth_type=endpoint.auth_type,
161-
auth_blob=endpoint.api_key if endpoint.api_key else "",
169+
# If an API key was provided or we've changed the auth type, we update the auth material
170+
if endpoint.auth_type != founddbe.auth_type or endpoint.api_key:
171+
await self._db_writer.push_provider_auth_material(
172+
dbmodels.ProviderAuthMaterial(
173+
provider_endpoint_id=dbendpoint.id,
174+
auth_type=endpoint.auth_type,
175+
auth_blob=endpoint.api_key if endpoint.api_key else "",
176+
)
162177
)
163-
)
164178

165179
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
166180

0 commit comments

Comments
 (0)