Skip to content

Commit

Permalink
Implement gRPC GetArtifactType (kubeflow#26)
Browse files Browse the repository at this point in the history
* Implement gRPC GetArtifactType

* Add test client code to current prototype IT
  • Loading branch information
tarilabs authored Sep 27, 2023
1 parent 0ac30c6 commit 7293d51
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
35 changes: 32 additions & 3 deletions internal/server/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,38 @@ func (g grpcServer) PutParentContexts(ctx context.Context, request *proto.PutPar
panic("implement me")
}

func (g grpcServer) GetArtifactType(ctx context.Context, request *proto.GetArtifactTypeRequest) (*proto.GetArtifactTypeResponse, error) {
//TODO implement me
panic("implement me")
func (g grpcServer) GetArtifactType(ctx context.Context, request *proto.GetArtifactTypeRequest) (resp *proto.GetArtifactTypeResponse, err error) {
ctx, dbConn := Begin(ctx, g.dbConnection)
defer handleTransaction(ctx, &err)

err = requiredFields(REQUIRED_TYPE_FIELDS, request.TypeName)
response := &proto.GetArtifactTypeResponse{}

var results []db.Type
rx := dbConn.Find(&results, db.Type{Name: *request.TypeName, TypeKind: int32(ARTIFACT_TYPE), Version: request.TypeVersion})
if rx.Error != nil {
return nil, rx.Error
}
if len(results) > 1 {
return nil, fmt.Errorf("more than one type found: %v", len(results))
}
if len(results) == 0 {
return response, nil
}

r0 := results[0]
artifactType := proto.ArtifactType{
Id: &r0.ID,
Name: &r0.Name,
Version: r0.Version,
Description: r0.Description,
ExternalId: r0.ExternalID,
}
for _, v := range r0.Properties {
artifactType.Properties[v.Name] = proto.PropertyType(v.DataType)
}
response.ArtifactType = &artifactType
return response, nil
}

func (g grpcServer) GetArtifactTypesByID(ctx context.Context, request *proto.GetArtifactTypesByIDRequest) (*proto.GetArtifactTypesByIDResponse, error) {
Expand Down
6 changes: 6 additions & 0 deletions test/python/test_mlmetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def main():
response = store.PutArtifactType(request)
model_type_id = response.type_id

request = metadata_store_service_pb2.GetArtifactTypeRequest()
request.type_name = "SavedModel"
response = store.GetArtifactType(request)
assert response.artifact_type.id == 2
assert response.artifact_type.name == "SavedModel"

# Query all registered Artifact types.
# artifact_types = store.GetArtifactTypes()

Expand Down

0 comments on commit 7293d51

Please sign in to comment.