Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch requests to artifact endpoint make mr panic #718

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions clients/python/tests/regression_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import requests

from model_registry import ModelRegistry
from model_registry.types.artifacts import ModelArtifact
Expand Down Expand Up @@ -99,3 +100,35 @@ async def test_create_standalone_model_artifact(client: ModelRegistry):
assert mv.id
mv_ma = await client._api.upsert_model_version_artifact(new_ma, mv.id)
assert mv_ma.id == new_ma.id

@pytest.mark.e2e
async def test_patch_model_artifacts_artifact_type(client: ModelRegistry):
"""Patching Artifacts makes the model registry server panic.

reported with https://issues.redhat.com/browse/RHOAIENG-16932
"""
name = "test_model"
version = "1.0.0"
rm = client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version=version,
)
assert rm.id
mv = client.get_model_version(name, version)
assert mv
assert mv.id
ma = client.get_model_artifact(name, version)
assert ma
assert ma.id

payload = { "modelFormatName": "foo", "artifactType": "model-artifact" }
from .conftest import REGISTRY_HOST, REGISTRY_PORT
response = requests.patch(url=f"{REGISTRY_HOST}:{REGISTRY_PORT}/api/model_registry/v1alpha3/artifacts/{ma.id}", json=payload, timeout=10, headers={"Content-Type": "application/json"})
assert response.status_code == 200
ma = client.get_model_artifact(name, version)
assert ma
assert ma.id
assert ma.model_format_name == "foo"
26 changes: 18 additions & 8 deletions internal/converter/openapi_reconciler_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,28 @@ import (

func UpdateExistingArtifact(genc OpenAPIReconciler, source OpenapiUpdateWrapper[openapi.Artifact]) (openapi.Artifact, error) {
art := InitWithExisting(source)

if source.Update == nil {
return art, nil
}
ma, err := genc.UpdateExistingModelArtifact(OpenapiUpdateWrapper[openapi.ModelArtifact]{Existing: art.ModelArtifact, Update: source.Update.ModelArtifact})
if err != nil {
return art, err

if source.Update.ModelArtifact != nil {
ma, err := genc.UpdateExistingModelArtifact(OpenapiUpdateWrapper[openapi.ModelArtifact]{Existing: art.ModelArtifact, Update: source.Update.ModelArtifact})
if err != nil {
return art, err
}

art.ModelArtifact = &ma
}
da, err := genc.UpdateExistingDocArtifact(OpenapiUpdateWrapper[openapi.DocArtifact]{Existing: art.DocArtifact, Update: source.Update.DocArtifact})
if err != nil {
return art, err

if source.Update.DocArtifact != nil {
da, err := genc.UpdateExistingDocArtifact(OpenapiUpdateWrapper[openapi.DocArtifact]{Existing: art.DocArtifact, Update: source.Update.DocArtifact})
if err != nil {
return art, err
}

art.DocArtifact = &da
}
art.DocArtifact = &da
art.ModelArtifact = &ma

return art, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@ func (s *ModelRegistryServiceAPIService) UpdateArtifact(ctx context.Context, art
}
if artifactUpdate.DocArtifactUpdate != nil {
entity.DocArtifact.Id = &artifactId
} else {
}
if artifactUpdate.ModelArtifactUpdate != nil {
entity.ModelArtifact.Id = &artifactId
}
existing, err := s.coreApi.GetArtifactById(artifactId)
Expand Down
Loading