diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 0befdc450..9fc516cf4 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -35,6 +35,11 @@ If you want to use a different port, mock kubernetes client or model registry cl ```shell make run PORT=8000 MOCK_K8S_CLIENT=true MOCK_MR_CLIENT=true ``` +If you want to change the log level on deployment, add the LOG_LEVEL argument when running, supported levels are: ERROR, WARN, INFO, DEBUG. The default level is INFO. +```shell +# Run with debug logging +make run LOG_LEVEL=DEBUG +``` # Building and Deploying diff --git a/clients/ui/bff/internal/api/healthcheck__handler_test.go b/clients/ui/bff/internal/api/healthcheck__handler_test.go index 0212a58cd..e830afb21 100644 --- a/clients/ui/bff/internal/api/healthcheck__handler_test.go +++ b/clients/ui/bff/internal/api/healthcheck__handler_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "github.com/kubeflow/model-registry/ui/bff/internal/config" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" @@ -26,7 +27,7 @@ func TestHealthCheckHandler(t *testing.T) { rr := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, HealthCheckPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) req = req.WithContext(ctx) assert.NoError(t, err) diff --git a/clients/ui/bff/internal/api/healthcheck_handler.go b/clients/ui/bff/internal/api/healthcheck_handler.go index 6ee2049a2..4692d6be7 100644 --- a/clients/ui/bff/internal/api/healthcheck_handler.go +++ b/clients/ui/bff/internal/api/healthcheck_handler.go @@ -3,12 +3,13 @@ package api import ( "errors" "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "net/http" ) func (app *App) HealthcheckHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - userId, ok := r.Context().Value(KubeflowUserIdKey).(string) + userId, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || userId == "" { app.serverErrorResponse(w, r, errors.New("failed to retrieve kubeflow-userid from context")) return diff --git a/clients/ui/bff/internal/api/middleware.go b/clients/ui/bff/internal/api/middleware.go index 03436ed66..bec3d0595 100644 --- a/clients/ui/bff/internal/api/middleware.go +++ b/clients/ui/bff/internal/api/middleware.go @@ -7,30 +7,13 @@ import ( "github.com/google/uuid" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/ui/bff/internal/config" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "log/slog" "net/http" "strings" ) -type contextKey string - -const ( - ModelRegistryHttpClientKey contextKey = "ModelRegistryHttpClientKey" - NamespaceHeaderParameterKey contextKey = "namespace" - - //Kubeflow authorization operates using custom authentication headers: - // Note: The functionality for `kubeflow-groups` is not fully operational at Kubeflow platform at this time - // but it's supported on Model Registry BFF - KubeflowUserIdKey contextKey = "kubeflowUserId" // kubeflow-userid :contains the user's email address - KubeflowUserIDHeader = "kubeflow-userid" - KubeflowUserGroupsKey contextKey = "kubeflowUserGroups" // kubeflow-groups : Holds a comma-separated list of user groups - KubeflowUserGroupsIdHeader = "kubeflow-groups" - - TraceIdKey contextKey = "TraceIdKey" - TraceLoggerKey contextKey = "TraceLoggerKey" -) - func (app *App) RecoverPanic(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { @@ -53,8 +36,8 @@ func (app *App) InjectUserHeaders(next http.Handler) http.Handler { return } - userIdHeader := r.Header.Get(KubeflowUserIDHeader) - userGroupsHeader := r.Header.Get(KubeflowUserGroupsIdHeader) + userIdHeader := r.Header.Get(constants.KubeflowUserIDHeader) + userGroupsHeader := r.Header.Get(constants.KubeflowUserGroupsIdHeader) //`kubeflow-userid`: Contains the user's email address. if userIdHeader == "" { app.badRequestResponse(w, r, errors.New("missing required header: kubeflow-userid")) @@ -74,8 +57,8 @@ func (app *App) InjectUserHeaders(next http.Handler) http.Handler { } ctx := r.Context() - ctx = context.WithValue(ctx, KubeflowUserIdKey, userIdHeader) - ctx = context.WithValue(ctx, KubeflowUserGroupsKey, userGroups) + ctx = context.WithValue(ctx, constants.KubeflowUserIdKey, userIdHeader) + ctx = context.WithValue(ctx, constants.KubeflowUserGroupsKey, userGroups) next.ServeHTTP(w, r.WithContext(ctx)) }) @@ -95,12 +78,12 @@ func (app *App) EnableTelemetry(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Adds a unique id to the context to allow tracing of requests traceId := uuid.NewString() - ctx := context.WithValue(r.Context(), TraceIdKey, traceId) + ctx := context.WithValue(r.Context(), constants.TraceIdKey, traceId) // logger will only be nil in tests. if app.logger != nil { traceLogger := app.logger.With(slog.String("trace_id", traceId)) - ctx = context.WithValue(ctx, TraceLoggerKey, traceLogger) + ctx = context.WithValue(ctx, constants.TraceLoggerKey, traceLogger) if traceLogger.Enabled(ctx, slog.LevelDebug) { cloneBody, err := integrations.CloneBody(r) @@ -121,12 +104,12 @@ func (app *App) AttachRESTClient(next func(http.ResponseWriter, *http.Request, h modelRegistryID := ps.ByName(ModelRegistryId) - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in the context")) } - modelRegistryBaseURL, err := resolveModelRegistryURL(namespace, modelRegistryID, app.kubernetesClient, app.config) + modelRegistryBaseURL, err := resolveModelRegistryURL(r.Context(), namespace, modelRegistryID, app.kubernetesClient, app.config) if err != nil { app.notFoundResponse(w, r) return @@ -135,7 +118,7 @@ func (app *App) AttachRESTClient(next func(http.ResponseWriter, *http.Request, h // Set up a child logger for the rest client that automatically adds the request id to all statements for // tracing. restClientLogger := app.logger - traceId, ok := r.Context().Value(TraceIdKey).(string) + traceId, ok := r.Context().Value(constants.TraceIdKey).(string) if app.logger != nil { if ok { restClientLogger = app.logger.With(slog.String("trace_id", traceId)) @@ -149,14 +132,14 @@ func (app *App) AttachRESTClient(next func(http.ResponseWriter, *http.Request, h app.serverErrorResponse(w, r, fmt.Errorf("failed to create Kubernetes client: %v", err)) return } - ctx := context.WithValue(r.Context(), ModelRegistryHttpClientKey, client) + ctx := context.WithValue(r.Context(), constants.ModelRegistryHttpClientKey, client) next(w, r.WithContext(ctx), ps) } } -func resolveModelRegistryURL(namespace string, serviceName string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) { +func resolveModelRegistryURL(sessionCtx context.Context, namespace string, serviceName string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) { - serviceDetails, err := client.GetServiceDetailsByName(namespace, serviceName) + serviceDetails, err := client.GetServiceDetailsByName(sessionCtx, namespace, serviceName) if err != nil { return "", err } @@ -172,13 +155,13 @@ func resolveModelRegistryURL(namespace string, serviceName string, client integr func (app *App) AttachNamespace(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - namespace := r.URL.Query().Get(string(NamespaceHeaderParameterKey)) + namespace := r.URL.Query().Get(string(constants.NamespaceHeaderParameterKey)) if namespace == "" { - app.badRequestResponse(w, r, fmt.Errorf("missing required query parameter: %s", NamespaceHeaderParameterKey)) + app.badRequestResponse(w, r, fmt.Errorf("missing required query parameter: %s", constants.NamespaceHeaderParameterKey)) return } - ctx := context.WithValue(r.Context(), NamespaceHeaderParameterKey, namespace) + ctx := context.WithValue(r.Context(), constants.NamespaceHeaderParameterKey, namespace) r = r.WithContext(ctx) next(w, r, ps) @@ -187,19 +170,19 @@ func (app *App) AttachNamespace(next func(http.ResponseWriter, *http.Request, ht func (app *App) PerformSARonGetListServicesByNamespace(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - user, ok := r.Context().Value(KubeflowUserIdKey).(string) + user, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || user == "" { app.badRequestResponse(w, r, fmt.Errorf("missing user in context")) return } - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in context")) return } var userGroups []string - if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + if groups, ok := r.Context().Value(constants.KubeflowUserGroupsKey).([]string); ok { userGroups = groups } else { userGroups = []string{} @@ -222,13 +205,13 @@ func (app *App) PerformSARonGetListServicesByNamespace(next func(http.ResponseWr func (app *App) PerformSARonSpecificService(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - user, ok := r.Context().Value(KubeflowUserIdKey).(string) + user, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || user == "" { app.badRequestResponse(w, r, fmt.Errorf("missing user in context")) return } - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in context")) return @@ -241,7 +224,7 @@ func (app *App) PerformSARonSpecificService(next func(http.ResponseWriter, *http } var userGroups []string - if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + if groups, ok := r.Context().Value(constants.KubeflowUserGroupsKey).([]string); ok { userGroups = groups } else { userGroups = []string{} diff --git a/clients/ui/bff/internal/api/model_registry_handler.go b/clients/ui/bff/internal/api/model_registry_handler.go index 1600d0045..d5ee69e0c 100644 --- a/clients/ui/bff/internal/api/model_registry_handler.go +++ b/clients/ui/bff/internal/api/model_registry_handler.go @@ -3,6 +3,7 @@ package api import ( "fmt" "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/models" "net/http" ) @@ -11,12 +12,12 @@ type ModelRegistryListEnvelope Envelope[[]models.ModelRegistryModel, None] func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in the context")) } - registries, err := app.repositories.ModelRegistry.GetAllModelRegistries(app.kubernetesClient, namespace) + registries, err := app.repositories.ModelRegistry.GetAllModelRegistries(r.Context(), app.kubernetesClient, namespace) if err != nil { app.serverErrorResponse(w, r, err) return diff --git a/clients/ui/bff/internal/api/model_registry_handler_test.go b/clients/ui/bff/internal/api/model_registry_handler_test.go index 872121ce0..2e37d080c 100644 --- a/clients/ui/bff/internal/api/model_registry_handler_test.go +++ b/clients/ui/bff/internal/api/model_registry_handler_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" + "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" . "github.com/onsi/ginkgo/v2" @@ -28,7 +30,8 @@ var _ = Describe("TestModelRegistryHandler", func() { requestPath := fmt.Sprintf(" %s?namespace=kubeflow", ModelRegistryListPath) req, err := http.NewRequest(http.MethodGet, requestPath, nil) - ctx := context.WithValue(req.Context(), NamespaceHeaderParameterKey, "kubeflow") + ctx := mocks.NewMockSessionContext(req.Context()) + ctx = context.WithValue(ctx, constants.NamespaceHeaderParameterKey, "kubeflow") req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) diff --git a/clients/ui/bff/internal/api/model_versions_handler.go b/clients/ui/bff/internal/api/model_versions_handler.go index a945d0492..c69e3fa9d 100644 --- a/clients/ui/bff/internal/api/model_versions_handler.go +++ b/clients/ui/bff/internal/api/model_versions_handler.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/validation" "net/http" @@ -19,7 +20,7 @@ type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None] type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None] func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -47,7 +48,7 @@ func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, p } func (app *App) CreateModelVersionHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -101,7 +102,7 @@ func (app *App) CreateModelVersionHandler(w http.ResponseWriter, r *http.Request } func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -151,7 +152,7 @@ func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request } func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -174,7 +175,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, } func (app *App) CreateModelArtifactByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return diff --git a/clients/ui/bff/internal/api/namespaces_handler.go b/clients/ui/bff/internal/api/namespaces_handler.go index 80fb47014..88550531a 100644 --- a/clients/ui/bff/internal/api/namespaces_handler.go +++ b/clients/ui/bff/internal/api/namespaces_handler.go @@ -2,6 +2,7 @@ package api import ( "errors" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/models" "net/http" @@ -12,14 +13,14 @@ type NamespacesEnvelope Envelope[[]models.NamespaceModel, None] func (app *App) GetNamespacesHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - userId, ok := r.Context().Value(KubeflowUserIdKey).(string) + userId, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || userId == "" { app.serverErrorResponse(w, r, errors.New("failed to retrieve kubeflow-userid from context")) return } var userGroups []string - if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + if groups, ok := r.Context().Value(constants.KubeflowUserGroupsKey).([]string); ok { userGroups = groups } else { userGroups = []string{} diff --git a/clients/ui/bff/internal/api/namespaces_handler_test.go b/clients/ui/bff/internal/api/namespaces_handler_test.go index b4869058f..505956975 100644 --- a/clients/ui/bff/internal/api/namespaces_handler_test.go +++ b/clients/ui/bff/internal/api/namespaces_handler_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "github.com/kubeflow/model-registry/ui/bff/internal/config" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" @@ -31,7 +32,7 @@ var _ = Describe("TestNamespacesHandler", func() { It("should return only dora-namespace for doraNonAdmin@example.com", func() { By("creating the HTTP request with the kubeflow-userid header") req, err := http.NewRequest(http.MethodGet, NamespaceListPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.DoraNonAdminUser) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.DoraNonAdminUser) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) rr := httptest.NewRecorder() @@ -57,7 +58,7 @@ var _ = Describe("TestNamespacesHandler", func() { It("should return all namespaces for user@example.com", func() { By("creating the HTTP request with the kubeflow-userid header") req, err := http.NewRequest(http.MethodGet, NamespaceListPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) req.Header.Set("kubeflow-userid", "user@example.com") @@ -87,7 +88,7 @@ var _ = Describe("TestNamespacesHandler", func() { It("should return no namespaces for non-existent user", func() { By("creating the HTTP request with a non-existent kubeflow-userid") req, err := http.NewRequest(http.MethodGet, NamespaceListPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, "nonexistent@example.com") + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, "nonexistent@example.com") req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) rr := httptest.NewRecorder() diff --git a/clients/ui/bff/internal/api/registered_models_handler.go b/clients/ui/bff/internal/api/registered_models_handler.go index 98daef1cd..25dbb58ea 100644 --- a/clients/ui/bff/internal/api/registered_models_handler.go +++ b/clients/ui/bff/internal/api/registered_models_handler.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/validation" "net/http" @@ -16,7 +17,7 @@ type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None] type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None] func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -39,7 +40,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req } func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -93,7 +94,7 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -121,7 +122,7 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request } func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -171,7 +172,7 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -195,7 +196,7 @@ func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWrit } func (app *App) CreateModelVersionForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return diff --git a/clients/ui/bff/internal/api/test_utils.go b/clients/ui/bff/internal/api/test_utils.go index 952a03337..1f8f0cc93 100644 --- a/clients/ui/bff/internal/api/test_utils.go +++ b/clients/ui/bff/internal/api/test_utils.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" k8s "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" @@ -46,16 +47,16 @@ func setupApiTest[T any](method string, url string, body interface{}, k8sClient } // Set the kubeflow-userid header - req.Header.Set(KubeflowUserIDHeader, kubeflowUserIDHeaderValue) + req.Header.Set(constants.KubeflowUserIDHeader, kubeflowUserIDHeaderValue) - ctx := req.Context() - ctx = context.WithValue(ctx, ModelRegistryHttpClientKey, mockClient) - ctx = context.WithValue(ctx, KubeflowUserIdKey, kubeflowUserIDHeaderValue) - ctx = context.WithValue(ctx, NamespaceHeaderParameterKey, namespace) + ctx := mocks.NewMockSessionContext(req.Context()) + ctx = context.WithValue(ctx, constants.ModelRegistryHttpClientKey, mockClient) + ctx = context.WithValue(ctx, constants.KubeflowUserIdKey, kubeflowUserIDHeaderValue) + ctx = context.WithValue(ctx, constants.NamespaceHeaderParameterKey, namespace) mrHttpClient := k8s.HTTPClient{ ModelRegistryID: "model-registry", } - ctx = context.WithValue(ctx, ModelRegistryHttpClientKey, mrHttpClient) + ctx = context.WithValue(ctx, constants.ModelRegistryHttpClientKey, mrHttpClient) req = req.WithContext(ctx) rr := httptest.NewRecorder() diff --git a/clients/ui/bff/internal/api/user_handler.go b/clients/ui/bff/internal/api/user_handler.go index 9ec135ccc..1359ab18b 100644 --- a/clients/ui/bff/internal/api/user_handler.go +++ b/clients/ui/bff/internal/api/user_handler.go @@ -3,6 +3,7 @@ package api import ( "errors" "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/models" "net/http" ) @@ -11,7 +12,7 @@ type UserEnvelope Envelope[*models.User, None] func (app *App) UserHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - userId, ok := r.Context().Value(KubeflowUserIdKey).(string) + userId, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || userId == "" { app.serverErrorResponse(w, r, errors.New("failed to retrieve kubeflow-userid from context")) return diff --git a/clients/ui/bff/internal/api/user_handler_test.go b/clients/ui/bff/internal/api/user_handler_test.go index 13cbf95a8..2927fa101 100644 --- a/clients/ui/bff/internal/api/user_handler_test.go +++ b/clients/ui/bff/internal/api/user_handler_test.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "io" "net/http" @@ -34,7 +35,7 @@ var _ = Describe("TestUserHandler", func() { It("should show that KubeflowUserIDHeaderValue (user@example.com) is a cluster-admin", func() { By("creating the http request") req, err := http.NewRequest(http.MethodGet, UserPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) @@ -62,7 +63,7 @@ var _ = Describe("TestUserHandler", func() { It("should show that DoraNonAdminUser (doraNonAdmin@example.com) is not a cluster-admin", func() { By("creating the http request") req, err := http.NewRequest(http.MethodGet, UserPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, DoraNonAdminUser) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, DoraNonAdminUser) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) @@ -92,7 +93,7 @@ var _ = Describe("TestUserHandler", func() { By("creating the http request") req, err := http.NewRequest(http.MethodGet, UserPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, randomUser) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, randomUser) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) diff --git a/clients/ui/bff/internal/constants/keys.go b/clients/ui/bff/internal/constants/keys.go new file mode 100644 index 000000000..3051679f5 --- /dev/null +++ b/clients/ui/bff/internal/constants/keys.go @@ -0,0 +1,19 @@ +package constants + +type contextKey string + +const ( + ModelRegistryHttpClientKey contextKey = "ModelRegistryHttpClientKey" + NamespaceHeaderParameterKey contextKey = "namespace" + + //Kubeflow authorization operates using custom authentication headers: + // Note: The functionality for `kubeflow-groups` is not fully operational at Kubeflow platform at this time + // but it's supported on Model Registry BFF + KubeflowUserIdKey contextKey = "kubeflowUserId" // kubeflow-userid :contains the user's email address + KubeflowUserIDHeader = "kubeflow-userid" + KubeflowUserGroupsKey contextKey = "kubeflowUserGroups" // kubeflow-groups : Holds a comma-separated list of user groups + KubeflowUserGroupsIdHeader = "kubeflow-groups" + + TraceIdKey contextKey = "TraceIdKey" + TraceLoggerKey contextKey = "TraceLoggerKey" +) diff --git a/clients/ui/bff/internal/integrations/k8s.go b/clients/ui/bff/internal/integrations/k8s.go index 9b89bb02c..19e52ca51 100644 --- a/clients/ui/bff/internal/integrations/k8s.go +++ b/clients/ui/bff/internal/integrations/k8s.go @@ -3,6 +3,7 @@ package integrations import ( "context" "fmt" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" helper "github.com/kubeflow/model-registry/ui/bff/internal/helpers" authv1 "k8s.io/api/authorization/v1" corev1 "k8s.io/api/core/v1" @@ -22,9 +23,9 @@ import ( const ComponentLabelValue = "model-registry" type KubernetesClientInterface interface { - GetServiceNames(namespace string) ([]string, error) - GetServiceDetailsByName(namespace string, serviceName string) (ServiceDetails, error) - GetServiceDetails(namespace string) ([]ServiceDetails, error) + GetServiceNames(sessionCtx context.Context, namespace string) ([]string, error) + GetServiceDetailsByName(sessionCtx context.Context, namespace string, serviceName string) (ServiceDetails, error) + GetServiceDetails(sessionCtx context.Context, namespace string) ([]ServiceDetails, error) BearerToken() (string, error) Shutdown(ctx context.Context, logger *slog.Logger) error IsInCluster() bool @@ -152,8 +153,8 @@ func (kc *KubernetesClient) BearerToken() (string, error) { return kc.Token, nil } -func (kc *KubernetesClient) GetServiceNames(namespace string) ([]string, error) { - services, err := kc.GetServiceDetails(namespace) +func (kc *KubernetesClient) GetServiceNames(sessionCtx context.Context, namespace string) ([]string, error) { + services, err := kc.GetServiceDetails(sessionCtx, namespace) if err != nil { return nil, err } @@ -166,7 +167,7 @@ func (kc *KubernetesClient) GetServiceNames(namespace string) ([]string, error) return names, nil } -func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetails, error) { +func (kc *KubernetesClient) GetServiceDetails(sessionCtx context.Context, namespace string) ([]ServiceDetails, error) { if namespace == "" { return nil, fmt.Errorf("namespace cannot be empty") @@ -175,6 +176,8 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + sessionLogger := sessionCtx.Value(constants.TraceLoggerKey).(*slog.Logger) + serviceList := &corev1.ServiceList{} labelSelector := labels.SelectorFromSet(labels.Set{ @@ -202,12 +205,12 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail } } if !hasHTTPPort { - kc.Logger.Error("service missing HTTP port", "serviceName", service.Name) + sessionLogger.Error("service missing HTTP port", "serviceName", service.Name) continue } if service.Spec.ClusterIP == "" { - kc.Logger.Error("service missing valid ClusterIP", "serviceName", service.Name) + sessionLogger.Error("service missing valid ClusterIP", "serviceName", service.Name) continue } @@ -220,11 +223,11 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail } if displayName == "" { - kc.Logger.Warn("service missing displayName annotation", "serviceName", service.Name) + sessionLogger.Warn("service missing displayName annotation", "serviceName", service.Name) } if description == "" { - kc.Logger.Warn("service missing description annotation", "serviceName", service.Name) + sessionLogger.Warn("service missing description annotation", "serviceName", service.Name) } serviceDetails := ServiceDetails{ @@ -242,8 +245,8 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail return services, nil } -func (kc *KubernetesClient) GetServiceDetailsByName(namespace string, serviceName string) (ServiceDetails, error) { - services, err := kc.GetServiceDetails(namespace) +func (kc *KubernetesClient) GetServiceDetailsByName(sessionCtx context.Context, namespace string, serviceName string) (ServiceDetails, error) { + services, err := kc.GetServiceDetails(sessionCtx, namespace) if err != nil { return ServiceDetails{}, fmt.Errorf("failed to get service details: %w", err) } diff --git a/clients/ui/bff/internal/mocks/k8s_mock.go b/clients/ui/bff/internal/mocks/k8s_mock.go index ce2b3e61a..7a74e65df 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock.go +++ b/clients/ui/bff/internal/mocks/k8s_mock.go @@ -185,8 +185,8 @@ func setupMock(mockK8sClient client.Client, ctx context.Context) error { return nil } -func (m *KubernetesClientMock) GetServiceDetails(namespace string) ([]k8s.ServiceDetails, error) { - originalServices, err := m.KubernetesClient.GetServiceDetails(namespace) +func (m *KubernetesClientMock) GetServiceDetails(sessionCtx context.Context, namespace string) ([]k8s.ServiceDetails, error) { + originalServices, err := m.KubernetesClient.GetServiceDetails(sessionCtx, namespace) if err != nil { return nil, fmt.Errorf("failed to get service details: %w", err) } @@ -199,8 +199,8 @@ func (m *KubernetesClientMock) GetServiceDetails(namespace string) ([]k8s.Servic return originalServices, nil } -func (m *KubernetesClientMock) GetServiceDetailsByName(namespace string, serviceName string) (k8s.ServiceDetails, error) { - originalService, err := m.KubernetesClient.GetServiceDetailsByName(namespace, serviceName) +func (m *KubernetesClientMock) GetServiceDetailsByName(sessionCtx context.Context, namespace string, serviceName string) (k8s.ServiceDetails, error) { + originalService, err := m.KubernetesClient.GetServiceDetailsByName(sessionCtx, namespace, serviceName) if err != nil { return k8s.ServiceDetails{}, fmt.Errorf("failed to get service details: %w", err) } diff --git a/clients/ui/bff/internal/mocks/k8s_mock_test.go b/clients/ui/bff/internal/mocks/k8s_mock_test.go index e236326aa..3f84783fa 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock_test.go +++ b/clients/ui/bff/internal/mocks/k8s_mock_test.go @@ -11,7 +11,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the get all service successfully", func() { By("getting service details") - services, err := k8sClient.GetServiceDetails("kubeflow") + services, err := k8sClient.GetServiceDetails(NewMockSessionContextNoParent(), "kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to create HTTP request") By("checking that all services have the modified ClusterIP and HTTPPort") @@ -37,7 +37,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the service details by name", func() { By("getting service by name") - service, err := k8sClient.GetServiceDetailsByName("dora-namespace", "model-registry-dora") + service, err := k8sClient.GetServiceDetailsByName(NewMockSessionContextNoParent(), "dora-namespace", "model-registry-dora") Expect(err).NotTo(HaveOccurred(), "Failed to create k8s request") By("checking that service details are correct") @@ -49,7 +49,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the services names", func() { By("getting service by name") - services, err := k8sClient.GetServiceNames("kubeflow") + services, err := k8sClient.GetServiceNames(NewMockSessionContextNoParent(), "kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to create HTTP request") By("checking that service details are correct") diff --git a/clients/ui/bff/internal/mocks/static_data_mock.go b/clients/ui/bff/internal/mocks/static_data_mock.go index bb0408c27..252a285e8 100644 --- a/clients/ui/bff/internal/mocks/static_data_mock.go +++ b/clients/ui/bff/internal/mocks/static_data_mock.go @@ -1,7 +1,12 @@ package mocks import ( + "context" + "github.com/google/uuid" "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" + "log/slog" + "os" ) func GetRegisteredModelMocks() []openapi.RegisteredModel { @@ -200,3 +205,19 @@ func newCustomProperties() *map[string]openapi.MetadataValue { return &result } + +func NewMockSessionContext(parent context.Context) context.Context { + if parent == nil { + parent = context.TODO() + } + traceId := uuid.NewString() + ctx := context.WithValue(parent, constants.TraceIdKey, traceId) + + traceLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ctx = context.WithValue(ctx, constants.TraceLoggerKey, traceLogger) + return ctx +} + +func NewMockSessionContextNoParent() context.Context { + return NewMockSessionContext(context.TODO()) +} diff --git a/clients/ui/bff/internal/repositories/model_registry.go b/clients/ui/bff/internal/repositories/model_registry.go index db4175952..60ec9bbb0 100644 --- a/clients/ui/bff/internal/repositories/model_registry.go +++ b/clients/ui/bff/internal/repositories/model_registry.go @@ -1,6 +1,7 @@ package repositories import ( + "context" "fmt" k8s "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/models" @@ -13,9 +14,9 @@ func NewModelRegistryRepository() *ModelRegistryRepository { return &ModelRegistryRepository{} } -func (m *ModelRegistryRepository) GetAllModelRegistries(client k8s.KubernetesClientInterface, namespace string) ([]models.ModelRegistryModel, error) { +func (m *ModelRegistryRepository) GetAllModelRegistries(sessionCtx context.Context, client k8s.KubernetesClientInterface, namespace string) ([]models.ModelRegistryModel, error) { - resources, err := client.GetServiceDetails(namespace) + resources, err := client.GetServiceDetails(sessionCtx, namespace) if err != nil { return nil, fmt.Errorf("error fetching model registries: %w", err) } diff --git a/clients/ui/bff/internal/repositories/model_registry_test.go b/clients/ui/bff/internal/repositories/model_registry_test.go index a5a0d903b..06ef12621 100644 --- a/clients/ui/bff/internal/repositories/model_registry_test.go +++ b/clients/ui/bff/internal/repositories/model_registry_test.go @@ -1,6 +1,7 @@ package repositories import ( + "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -13,7 +14,7 @@ var _ = Describe("TestFetchAllModelRegistry", func() { By("fetching all model registries in the repository") modelRegistryRepository := NewModelRegistryRepository() - registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "kubeflow") + registries, err := modelRegistryRepository.GetAllModelRegistries(mocks.NewMockSessionContextNoParent(), k8sClient, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model registries") @@ -28,7 +29,7 @@ var _ = Describe("TestFetchAllModelRegistry", func() { By("fetching all model registries in the repository") modelRegistryRepository := NewModelRegistryRepository() - registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "dora-namespace") + registries, err := modelRegistryRepository.GetAllModelRegistries(mocks.NewMockSessionContextNoParent(), k8sClient, "dora-namespace") Expect(err).NotTo(HaveOccurred()) By("should match the expected model registries") @@ -42,7 +43,7 @@ var _ = Describe("TestFetchAllModelRegistry", func() { By("fetching all model registries in the repository") modelRegistryRepository := NewModelRegistryRepository() - registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "no-namespace") + registries, err := modelRegistryRepository.GetAllModelRegistries(mocks.NewMockSessionContextNoParent(), k8sClient, "no-namespace") Expect(err).NotTo(HaveOccurred()) By("should be empty")