From 44d738e02ce6ad3b0a80d2ea4993b99e912da20a Mon Sep 17 00:00:00 2001 From: Alessio Pragliola Date: Tue, 14 Jan 2025 21:22:08 +0100 Subject: [PATCH 1/3] feat(proxy): make the proxy resilient on mlmd failure Signed-off-by: Alessio Pragliola --- api/openapi/model-registry.yaml | 86 +++++ cmd/proxy.go | 114 ++++-- internal/grpc/backoff.go | 67 ++++ internal/mlmdtypes/mlmdtypes.go | 14 +- internal/proxy/dynamic_router.go | 39 ++ pkg/api/error.go | 10 + pkg/openapi/api_model_registry_service.go | 440 ++++++++++++++++++++++ 7 files changed, 730 insertions(+), 40 deletions(-) create mode 100644 internal/grpc/backoff.go create mode 100644 internal/proxy/dynamic_router.go diff --git a/api/openapi/model-registry.yaml b/api/openapi/model-registry.yaml index ccbd67430..f70058ad5 100644 --- a/api/openapi/model-registry.yaml +++ b/api/openapi/model-registry.yaml @@ -28,6 +28,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: findArtifact summary: Get an Artifact that matches search parameters. description: Gets the details of a single instance of an `Artifact` that matches search parameters. @@ -58,6 +60,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getArtifacts summary: List All Artifacts description: Gets a list of all `Artifact` entities. @@ -80,6 +84,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createArtifact summary: Create an Artifact description: Creates a new instance of an `Artifact`. 
@@ -99,6 +105,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getArtifact summary: Get an Artifact description: Gets the details of a single instance of an `Artifact`. @@ -123,6 +131,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: updateArtifact summary: Update an Artifact description: Updates an existing `Artifact`. @@ -151,6 +161,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: findModelArtifact summary: Get a ModelArtifact that matches search parameters. description: Gets the details of a single instance of a `ModelArtifact` that matches search parameters. @@ -181,6 +193,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getModelArtifacts summary: List All ModelArtifacts description: Gets a list of all `ModelArtifact` entities. @@ -203,6 +217,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createModelArtifact summary: Create a ModelArtifact description: Creates a new instance of a `ModelArtifact`. @@ -222,6 +238,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getModelArtifact summary: Get a ModelArtifact description: Gets the details of a single instance of a `ModelArtifact`. 
@@ -246,6 +264,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: updateModelArtifact summary: Update a ModelArtifact description: Updates an existing `ModelArtifact`. @@ -275,6 +295,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getModelVersions summary: List All ModelVersions description: Gets a list of all `ModelVersion` entities. @@ -297,6 +319,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createModelVersion summary: Create a ModelVersion description: Creates a new instance of a `ModelVersion`. @@ -316,6 +340,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getModelVersion summary: Get a ModelVersion description: Gets the details of a single instance of a `ModelVersion`. @@ -340,6 +366,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: updateModelVersion summary: Update a ModelVersion description: Updates an existing `ModelVersion`. @@ -366,6 +394,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: findRegisteredModel summary: Get a RegisteredModel that matches search parameters. description: Gets the details of a single instance of a `RegisteredModel` that matches search parameters. 
@@ -391,6 +421,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getRegisteredModels summary: List All RegisteredModels description: Gets a list of all `RegisteredModel` entities. @@ -413,6 +445,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createRegisteredModel summary: Create a RegisteredModel description: Creates a new instance of a `RegisteredModel`. @@ -432,6 +466,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getRegisteredModel summary: Get a RegisteredModel description: Gets the details of a single instance of a `RegisteredModel`. @@ -456,6 +492,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: updateRegisteredModel summary: Update a RegisteredModel description: Updates an existing `RegisteredModel`. @@ -489,6 +527,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getModelVersionArtifacts summary: List all artifacts associated with the `ModelVersion` post: @@ -514,6 +554,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: upsertModelVersionArtifact summary: Upsert an Artifact in a ModelVersion description: Creates a new instance of an Artifact if needed and associates it with `ModelVersion`. 
@@ -547,6 +589,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getRegisteredModelVersions summary: List All RegisteredModel's ModelVersions description: Gets a list of all `ModelVersion` entities for the `RegisteredModel`. @@ -571,6 +615,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createRegisteredModelVersion summary: Create a ModelVersion in RegisteredModel description: Creates a new instance of a `ModelVersion`. @@ -599,6 +645,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: findInferenceService summary: Get an InferenceServices that matches search parameters. description: Gets the details of a single instance of `InferenceService` that matches search parameters. @@ -622,6 +670,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getInferenceService summary: Get a InferenceService description: Gets the details of a single instance of a `InferenceService`. @@ -646,6 +696,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: updateInferenceService summary: Update a InferenceService description: Updates an existing `InferenceService`. 
@@ -679,6 +731,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getInferenceServices summary: List All InferenceServices description: Gets a list of all `InferenceService` entities. @@ -701,6 +755,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createInferenceService summary: Create a InferenceService description: Creates a new instance of a `InferenceService`. @@ -720,6 +776,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: findServingEnvironment summary: Find ServingEnvironment description: Finds a `ServingEnvironment` entity that matches query parameters. @@ -745,6 +803,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getServingEnvironments summary: List All ServingEnvironments description: Gets a list of all `ServingEnvironment` entities. @@ -767,6 +827,8 @@ paths: $ref: "#/components/responses/Unauthorized" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createServingEnvironment summary: Create a ServingEnvironment description: Creates a new instance of a `ServingEnvironment`. @@ -786,6 +848,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getServingEnvironment summary: Get a ServingEnvironment description: Gets the details of a single instance of a `ServingEnvironment`. 
@@ -810,6 +874,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: updateServingEnvironment summary: Update a ServingEnvironment description: Updates an existing `ServingEnvironment`. @@ -843,6 +909,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getEnvironmentInferenceServices summary: List All ServingEnvironment's InferenceServices description: Gets a list of all `InferenceService` entities for the `ServingEnvironment`. @@ -867,6 +935,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createEnvironmentInferenceService summary: Create a InferenceService in ServingEnvironment description: Creates a new instance of a `InferenceService`. @@ -900,6 +970,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getInferenceServiceServes summary: List All InferenceService's ServeModel actions description: Gets a list of all `ServeModel` entities for the `InferenceService`. @@ -924,6 +996,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: createInferenceServiceServe summary: Create a ServeModel action in a InferenceService description: Creates a new instance of a `ServeModel` associated with `InferenceService`. 
@@ -950,6 +1024,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getInferenceServiceModel summary: Get InferenceService's RegisteredModel description: Gets the `RegisteredModel` entity for the `InferenceService`. @@ -976,6 +1052,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: getInferenceServiceVersion summary: Get InferenceService's ModelVersion description: Gets the `ModelVersion` entity for the `InferenceService`. @@ -1004,6 +1082,8 @@ paths: $ref: "#/components/responses/NotFound" "500": $ref: "#/components/responses/InternalServerError" + "503": + $ref: "#/components/responses/ServiceUnavailable" operationId: findModelVersion summary: Get a ModelVersion that matches search parameters. description: Gets the details of a single instance of a `ModelVersion` that matches search parameters. 
@@ -1667,6 +1747,12 @@ components: schema: $ref: "#/components/schemas/Error" description: Unexpected internal server error + ServiceUnavailable: + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + description: Service is unavailable ModelArtifactListResponse: content: application/json: diff --git a/cmd/proxy.go b/cmd/proxy.go index c313607b5..6ddb62279 100644 --- a/cmd/proxy.go +++ b/cmd/proxy.go @@ -1,13 +1,14 @@ package cmd import ( - "context" "fmt" "net/http" - "time" + "sync" "github.com/golang/glog" + mrGrpc "github.com/kubeflow/model-registry/internal/grpc" "github.com/kubeflow/model-registry/internal/mlmdtypes" + "github.com/kubeflow/model-registry/internal/proxy" "github.com/kubeflow/model-registry/internal/server/openapi" "github.com/kubeflow/model-registry/pkg/core" "github.com/spf13/cobra" @@ -15,6 +16,8 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +const mlmdUnavailableMessage = "MLMD server is down or unavailable. Please check that the database is reachable and try again later." 
+ // proxyCmd represents the proxy command var proxyCmd = &cobra.Command{ Use: "proxy", @@ -27,42 +30,87 @@ hostname and port where it listens.'`, } func runProxyServer(cmd *cobra.Command, args []string) error { - glog.Infof("proxy server started at %s:%v", cfg.Hostname, cfg.Port) - - ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort) - glog.Infof("connecting to MLMD server %s..", mlmdAddr) - conn, err := grpc.DialContext( // nolint:staticcheck - ctxTimeout, - mlmdAddr, - grpc.WithReturnConnectionError(), // nolint:staticcheck - grpc.WithBlock(), // nolint:staticcheck - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - return fmt.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err) - } - defer conn.Close() - glog.Infof("connected to MLMD server") + var conn *grpc.ClientConn + var err error + var wg sync.WaitGroup - mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults() - _, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig) - if err != nil { - return fmt.Errorf("error creating MLMD types: %v", err) - } - service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig) + errChan := make(chan error, 1) + + router := proxy.NewDynamicRouter() + + router.SetRouter(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, mlmdUnavailableMessage, http.StatusServiceUnavailable) + })) + + // Start the connection to the MLMD server in a separate goroutine, so that + // we can start the proxy server and start serving requests while we wait + // for the connection to be established. 
+ go func() { + defer close(errChan) + + mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort) + glog.Infof("connecting to MLMD server %s..", mlmdAddr) + conn, err = grpc.NewClient(mlmdAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + errChan <- fmt.Errorf("error dialing connection to mlmd server %s: %w", mlmdAddr, err) + + return + } + + mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults() + + _, err = mrGrpc.RetryOnGRPCError[map[string]int64](mlmdtypes.CreateMLMDTypes, conn, mlmdTypeNamesConfig) + if err != nil { + errChan <- fmt.Errorf("error creating MLMD types: %w", err) + + return + } + service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig) + if err != nil { + errChan <- fmt.Errorf("error creating core service: %w", err) + + return + } + + ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service) + ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService) + + router.SetRouter(openapi.NewRouter(ModelRegistryServiceAPIController)) + + glog.Infof("connected to MLMD server") + }() + + wg.Add(1) + + // Start the proxy server in a separate goroutine so that we can handle + // errors from both the proxy server and the connection to the MLMD server. 
+ go func() { + defer wg.Done() + + glog.Infof("proxy server started at %s:%v", cfg.Hostname, cfg.Port) + + err := http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) + if err != nil { + errChan <- err + } + }() + + defer func() { + if conn != nil { + glog.Info("closing connection to MLMD server") + + conn.Close() + } + }() + + err = <-errChan if err != nil { - return fmt.Errorf("error creating core service: %v", err) + return fmt.Errorf("error starting proxy server: %w", err) } - ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service) - ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService) - - router := openapi.NewRouter(ModelRegistryServiceAPIController) + // Wait for the proxy server to finish serving requests. + wg.Wait() - glog.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)) return nil } diff --git a/internal/grpc/backoff.go b/internal/grpc/backoff.go new file mode 100644 index 000000000..dd68b7544 --- /dev/null +++ b/internal/grpc/backoff.go @@ -0,0 +1,67 @@ +package grpc + +import ( + "fmt" + "reflect" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + // maxRetryAttempts is the maximum number of times to retry a request. 
+ maxRetryAttempts = 25 +) + +var retryableStatusCodes = map[codes.Code]bool{ + codes.Unavailable: true, +} + +func RetryOnGRPCError[T any](funcToRetry any, funcParams ...any) (T, error) { + var outErr error + var res T + + fnVal := reflect.ValueOf(funcToRetry) + if fnVal.Kind() != reflect.Func { + return res, fmt.Errorf("grpc retry error: function parameter is not a function") + } + + if len(funcParams) != fnVal.Type().NumIn() { + return res, fmt.Errorf("grpc retry error: function parameters count mismatch") + } + + inputs := make([]reflect.Value, len(funcParams)) + for i, param := range funcParams { + inputs[i] = reflect.ValueOf(param) + } + + for i := 0; i < maxRetryAttempts; i++ { + outs := fnVal.Call(inputs) + + if len(outs) != 2 { + return res, fmt.Errorf("grpc retry error: function from input does not return 2 values") + } + + res = outs[0].Interface().(T) + errI := outs[1].Interface() + if errI == nil { + outErr = nil + } else { + outErr = errI.(error) + } + + if status, ok := status.FromError(outErr); ok { + if !retryableStatusCodes[status.Code()] { + break + } + } else { + break + } + + backoff := time.Duration(i+1) * time.Second + time.Sleep(backoff) + } + + return res, outErr +} diff --git a/internal/mlmdtypes/mlmdtypes.go b/internal/mlmdtypes/mlmdtypes.go index 7a169224c..7a7f6effa 100644 --- a/internal/mlmdtypes/mlmdtypes.go +++ b/internal/mlmdtypes/mlmdtypes.go @@ -128,37 +128,37 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface, nameConfig MLMDTypeNamesConfig registeredModelResp, err := client.PutContextType(context.Background(), ®isteredModelReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.RegisteredModelTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.RegisteredModelTypeName, err) } modelVersionResp, err := client.PutContextType(context.Background(), &modelVersionReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", 
nameConfig.ModelVersionTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.ModelVersionTypeName, err) } docArtifactResp, err := client.PutArtifactType(context.Background(), &docArtifactReq) if err != nil { - return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.DocArtifactTypeName, err) + return nil, fmt.Errorf("error setting up artifact type %s: %w", nameConfig.DocArtifactTypeName, err) } modelArtifactResp, err := client.PutArtifactType(context.Background(), &modelArtifactReq) if err != nil { - return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.ModelArtifactTypeName, err) + return nil, fmt.Errorf("error setting up artifact type %s: %w", nameConfig.ModelArtifactTypeName, err) } servingEnvironmentResp, err := client.PutContextType(context.Background(), &servingEnvironmentReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.ServingEnvironmentTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.ServingEnvironmentTypeName, err) } inferenceServiceResp, err := client.PutContextType(context.Background(), &inferenceServiceReq) if err != nil { - return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.InferenceServiceTypeName, err) + return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.InferenceServiceTypeName, err) } serveModelResp, err := client.PutExecutionType(context.Background(), &serveModelReq) if err != nil { - return nil, fmt.Errorf("error setting up execution type %s: %v", nameConfig.ServeModelTypeName, err) + return nil, fmt.Errorf("error setting up execution type %s: %w", nameConfig.ServeModelTypeName, err) } typesMap := map[string]int64{ diff --git a/internal/proxy/dynamic_router.go b/internal/proxy/dynamic_router.go new file mode 100644 index 000000000..343f9e3ff --- /dev/null +++ b/internal/proxy/dynamic_router.go @@ -0,0 +1,39 @@ +// Package proxy provides 
dynamic routing capabilities for HTTP servers. +// +// This file contains the implementation of a dynamic router that allows +// changing the HTTP handler at runtime in a thread-safe manner. It is +// particularly useful for proxy servers that need to update their routing +// logic without restarting the server. +package proxy + +import ( + "net/http" + "sync" +) + +type dynamicRouter struct { + mu sync.RWMutex + router http.Handler +} + +func NewDynamicRouter() *dynamicRouter { + return &dynamicRouter{} +} + +func (d *dynamicRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + d.mu.RLock() + + router := d.router + + d.mu.RUnlock() + + router.ServeHTTP(w, r) +} + +func (d *dynamicRouter) SetRouter(router http.Handler) { + d.mu.Lock() + + d.router = router + + d.mu.Unlock() +} diff --git a/pkg/api/error.go b/pkg/api/error.go index 04d857d5c..8f66a6d40 100644 --- a/pkg/api/error.go +++ b/pkg/api/error.go @@ -3,6 +3,9 @@ package api import ( "errors" "net/http" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -11,6 +14,13 @@ var ( ) func ErrToStatus(err error) int { + // If the error is a gRPC error, we can extract the status code. 
+ if status, ok := status.FromError(err); ok { + if status.Code() == codes.Unavailable { + return http.StatusServiceUnavailable + } + } + switch errors.Unwrap(err) { case ErrBadRequest: return http.StatusBadRequest diff --git a/pkg/openapi/api_model_registry_service.go b/pkg/openapi/api_model_registry_service.go index 4904ad210..86b9ff40b 100644 --- a/pkg/openapi/api_model_registry_service.go +++ b/pkg/openapi/api_model_registry_service.go @@ -150,6 +150,17 @@ func (a *ModelRegistryServiceAPIService) CreateArtifactExecute(r ApiCreateArtifa } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -309,6 +320,17 @@ func (a *ModelRegistryServiceAPIService) CreateEnvironmentInferenceServiceExecut } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -453,6 +475,17 @@ func (a *ModelRegistryServiceAPIService) CreateInferenceServiceExecute(r ApiCrea } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -612,6 +645,17 @@ func (a *ModelRegistryServiceAPIService) CreateInferenceServiceServeExecute(r Ap } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -756,6 +800,17 @@ func (a *ModelRegistryServiceAPIService) CreateModelArtifactExecute(r ApiCreateM } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -900,6 +955,17 @@ func (a *ModelRegistryServiceAPIService) CreateModelVersionExecute(r ApiCreateMo } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -1044,6 +1110,17 @@ func (a *ModelRegistryServiceAPIService) CreateRegisteredModelExecute(r ApiCreat } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -1203,6 +1280,17 @@ func (a *ModelRegistryServiceAPIService) CreateRegisteredModelVersionExecute(r A } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -1347,6 +1435,17 @@ func (a *ModelRegistryServiceAPIService) CreateServingEnvironmentExecute(r ApiCr } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -1520,6 +1619,17 @@ func (a *ModelRegistryServiceAPIService) FindArtifactExecute(r ApiFindArtifactRe } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -1693,6 +1803,17 @@ func (a *ModelRegistryServiceAPIService) FindInferenceServiceExecute(r ApiFindIn } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -1866,6 +1987,17 @@ func (a *ModelRegistryServiceAPIService) FindModelArtifactExecute(r ApiFindModel } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -2039,6 +2171,17 @@ func (a *ModelRegistryServiceAPIService) FindModelVersionExecute(r ApiFindModelV } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -2191,6 +2334,17 @@ func (a *ModelRegistryServiceAPIService) FindRegisteredModelExecute(r ApiFindReg } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -2343,6 +2497,17 @@ func (a *ModelRegistryServiceAPIService) FindServingEnvironmentExecute(r ApiFind } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -2479,6 +2644,17 @@ func (a *ModelRegistryServiceAPIService) GetArtifactExecute(r ApiGetArtifactRequ } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -2662,6 +2838,17 @@ func (a *ModelRegistryServiceAPIService) GetArtifactsExecute(r ApiGetArtifactsRe } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -2858,6 +3045,17 @@ func (a *ModelRegistryServiceAPIService) GetEnvironmentInferenceServicesExecute( } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -2994,6 +3192,17 @@ func (a *ModelRegistryServiceAPIService) GetInferenceServiceExecute(r ApiGetInfe } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -3130,6 +3339,17 @@ func (a *ModelRegistryServiceAPIService) GetInferenceServiceModelExecute(r ApiGe } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -3326,6 +3546,17 @@ func (a *ModelRegistryServiceAPIService) GetInferenceServiceServesExecute(r ApiG } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -3462,6 +3693,17 @@ func (a *ModelRegistryServiceAPIService) GetInferenceServiceVersionExecute(r Api } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -3645,6 +3887,17 @@ func (a *ModelRegistryServiceAPIService) GetInferenceServicesExecute(r ApiGetInf } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -3781,6 +4034,17 @@ func (a *ModelRegistryServiceAPIService) GetModelArtifactExecute(r ApiGetModelAr } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -3964,6 +4228,17 @@ func (a *ModelRegistryServiceAPIService) GetModelArtifactsExecute(r ApiGetModelA } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -4100,6 +4375,17 @@ func (a *ModelRegistryServiceAPIService) GetModelVersionExecute(r ApiGetModelVer } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -4294,6 +4580,17 @@ func (a *ModelRegistryServiceAPIService) GetModelVersionArtifactsExecute(r ApiGe } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -4455,6 +4752,17 @@ func (a *ModelRegistryServiceAPIService) GetModelVersionsExecute(r ApiGetModelVe } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -4591,6 +4899,17 @@ func (a *ModelRegistryServiceAPIService) GetRegisteredModelExecute(r ApiGetRegis } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -4787,6 +5106,17 @@ func (a *ModelRegistryServiceAPIService) GetRegisteredModelVersionsExecute(r Api } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -4948,6 +5278,17 @@ func (a *ModelRegistryServiceAPIService) GetRegisteredModelsExecute(r ApiGetRegi } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -5084,6 +5425,17 @@ func (a *ModelRegistryServiceAPIService) GetServingEnvironmentExecute(r ApiGetSe } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -5245,6 +5597,17 @@ func (a *ModelRegistryServiceAPIService) GetServingEnvironmentsExecute(r ApiGetS } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -5404,6 +5767,17 @@ func (a *ModelRegistryServiceAPIService) UpdateArtifactExecute(r ApiUpdateArtifa } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -5563,6 +5937,17 @@ func (a *ModelRegistryServiceAPIService) UpdateInferenceServiceExecute(r ApiUpda } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -5722,6 +6107,17 @@ func (a *ModelRegistryServiceAPIService) UpdateModelArtifactExecute(r ApiUpdateM } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -5881,6 +6277,17 @@ func (a *ModelRegistryServiceAPIService) UpdateModelVersionExecute(r ApiUpdateMo } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -6040,6 +6447,17 @@ func (a *ModelRegistryServiceAPIService) UpdateRegisteredModelExecute(r ApiUpdat } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -6199,6 +6617,17 @@ func (a *ModelRegistryServiceAPIService) UpdateServingEnvironmentExecute(r ApiUp } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, 
localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } @@ -6358,6 +6787,17 @@ func (a *ModelRegistryServiceAPIService) UpsertModelVersionArtifactExecute(r Api } newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 503 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v } return localVarReturnValue, localVarHTTPResponse, newErr } From 913e5f5e87fefd0e8e5fc0766dd3c7e7e64bf821 Mon Sep 17 00:00:00 2001 From: Alessio Pragliola Date: Tue, 14 Jan 2025 22:01:27 +0100 Subject: [PATCH 2/3] chore(proxy): simplify code by removing generic function Signed-off-by: Alessio Pragliola --- cmd/proxy.go | 31 +++++++++++++++---- internal/grpc/backoff.go | 67 ---------------------------------------- 2 files changed, 25 insertions(+), 73 deletions(-) delete mode 100644 internal/grpc/backoff.go diff --git a/cmd/proxy.go b/cmd/proxy.go index 6ddb62279..c2467b38f 100644 --- a/cmd/proxy.go +++ b/cmd/proxy.go @@ -4,19 +4,26 @@ import ( "fmt" "net/http" "sync" + "time" "github.com/golang/glog" - mrGrpc "github.com/kubeflow/model-registry/internal/grpc" "github.com/kubeflow/model-registry/internal/mlmdtypes" "github.com/kubeflow/model-registry/internal/proxy" "github.com/kubeflow/model-registry/internal/server/openapi" 
"github.com/kubeflow/model-registry/pkg/core" "github.com/spf13/cobra" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" ) -const mlmdUnavailableMessage = "MLMD server is down or unavailable. Please check that the database is reachable and try again later." +const ( + // mlmdUnavailableMessage is the message returned when the MLMD server is down or unavailable. + mlmdUnavailableMessage = "MLMD server is down or unavailable. Please check that the database is reachable and try again later." + // maxGRPCRetryAttempts is the maximum number of attempts to retry GRPC requests to the MLMD server. + maxGRPCRetryAttempts = 25 // 25 attempts with incremental backoff (1s, 2s, 3s, ..., 25s) it's ~5 minutes +) // proxyCmd represents the proxy command var proxyCmd = &cobra.Command{ @@ -59,12 +66,24 @@ func runProxyServer(cmd *cobra.Command, args []string) error { mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults() - _, err = mrGrpc.RetryOnGRPCError[map[string]int64](mlmdtypes.CreateMLMDTypes, conn, mlmdTypeNamesConfig) - if err != nil { - errChan <- fmt.Errorf("error creating MLMD types: %w", err) + // Backoff and retry GRPC requests to the MLMD server, until the server + // becomes available or the maximum number of attempts is reached. 
+ for i := 0; i < maxGRPCRetryAttempts; i++ { + _, err := mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig) + if err == nil { + break + } - return + st, ok := status.FromError(err) + if !ok || st.Code() != codes.Unavailable { + errChan <- fmt.Errorf("error creating MLMD types: %w", err) + + return + } + + time.Sleep(time.Duration(i+1) * time.Second) } + service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig) if err != nil { errChan <- fmt.Errorf("error creating core service: %w", err) diff --git a/internal/grpc/backoff.go b/internal/grpc/backoff.go deleted file mode 100644 index dd68b7544..000000000 --- a/internal/grpc/backoff.go +++ /dev/null @@ -1,67 +0,0 @@ -package grpc - -import ( - "fmt" - "reflect" - "time" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -const ( - // maxRetryAttempts is the maximum number of times to retry a request. - maxRetryAttempts = 25 -) - -var retryableStatusCodes = map[codes.Code]bool{ - codes.Unavailable: true, -} - -func RetryOnGRPCError[T any](funcToRetry any, funcParams ...any) (T, error) { - var outErr error - var res T - - fnVal := reflect.ValueOf(funcToRetry) - if fnVal.Kind() != reflect.Func { - return res, fmt.Errorf("grpc retry error: function parameter is not a function") - } - - if len(funcParams) != fnVal.Type().NumIn() { - return res, fmt.Errorf("grpc retry error: function parameters count mismatch") - } - - inputs := make([]reflect.Value, len(funcParams)) - for i, param := range funcParams { - inputs[i] = reflect.ValueOf(param) - } - - for i := 0; i < maxRetryAttempts; i++ { - outs := fnVal.Call(inputs) - - if len(outs) != 2 { - return res, fmt.Errorf("grpc retry error: function from input does not return 2 values") - } - - res = outs[0].Interface().(T) - errI := outs[1].Interface() - if errI == nil { - outErr = nil - } else { - outErr = errI.(error) - } - - if status, ok := status.FromError(outErr); ok { - if !retryableStatusCodes[status.Code()] { - break - } - } else { - 
break - } - - backoff := time.Duration(i+1) * time.Second - time.Sleep(backoff) - } - - return res, outErr -} From ba92d73836c950ca6d14250b96015b999985cfd7 Mon Sep 17 00:00:00 2001 From: Alessio Pragliola Date: Tue, 28 Jan 2025 21:18:40 +0100 Subject: [PATCH 3/3] fix(proxy): prevent race condition on err channel Signed-off-by: Alessio Pragliola --- cmd/proxy.go | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/cmd/proxy.go b/cmd/proxy.go index c2467b38f..092603665 100644 --- a/cmd/proxy.go +++ b/cmd/proxy.go @@ -3,7 +3,6 @@ package cmd import ( "fmt" "net/http" - "sync" "time" "github.com/golang/glog" @@ -39,9 +38,9 @@ hostname and port where it listens.'`, func runProxyServer(cmd *cobra.Command, args []string) error { var conn *grpc.ClientConn var err error - var wg sync.WaitGroup - errChan := make(chan error, 1) + errMLMDChan := make(chan error, 1) + errProxyChan := make(chan error, 1) router := proxy.NewDynamicRouter() @@ -53,13 +52,13 @@ func runProxyServer(cmd *cobra.Command, args []string) error { // we can start the proxy server and start serving requests while we wait // for the connection to be established. 
go func() { - defer close(errChan) + defer close(errMLMDChan) mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort) glog.Infof("connecting to MLMD server %s..", mlmdAddr) conn, err = grpc.NewClient(mlmdAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - errChan <- fmt.Errorf("error dialing connection to mlmd server %s: %w", mlmdAddr, err) + errMLMDChan <- fmt.Errorf("error dialing connection to mlmd server %s: %w", mlmdAddr, err) return } @@ -76,7 +75,7 @@ func runProxyServer(cmd *cobra.Command, args []string) error { st, ok := status.FromError(err) if !ok || st.Code() != codes.Unavailable { - errChan <- fmt.Errorf("error creating MLMD types: %w", err) + errMLMDChan <- fmt.Errorf("error creating MLMD types: %w", err) return } @@ -86,7 +85,7 @@ func runProxyServer(cmd *cobra.Command, args []string) error { service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig) if err != nil { - errChan <- fmt.Errorf("error creating core service: %w", err) + errMLMDChan <- fmt.Errorf("error creating core service: %w", err) return } @@ -99,18 +98,16 @@ func runProxyServer(cmd *cobra.Command, args []string) error { glog.Infof("connected to MLMD server") }() - wg.Add(1) - // Start the proxy server in a separate goroutine so that we can handle // errors from both the proxy server and the connection to the MLMD server. 
go func() { - defer wg.Done() + defer close(errProxyChan) glog.Infof("proxy server started at %s:%v", cfg.Hostname, cfg.Port) err := http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) if err != nil { - errChan <- err + errProxyChan <- fmt.Errorf("error starting proxy server: %w", err) } }() @@ -122,15 +119,25 @@ func runProxyServer(cmd *cobra.Command, args []string) error { } }() - err = <-errChan - if err != nil { - return fmt.Errorf("error starting proxy server: %w", err) - } + // Wait for either the MLMD server connection or the proxy server to return an error + // or for both to finish successfully. + for { + select { + case err := <-errMLMDChan: + if err != nil { + return err + } - // Wait for the proxy server to finish serving requests. - wg.Wait() + case err := <-errProxyChan: + if err != nil { + return err + } + } - return nil + if errMLMDChan == nil && errProxyChan == nil { + return nil + } + } } func init() {