From c8586151fbb1b8b988a829d11f90ffc750e46916 Mon Sep 17 00:00:00 2001 From: Dhiraj Bokde Date: Wed, 22 Jan 2025 12:47:47 -0800 Subject: [PATCH] add support to disable catalog or registry service endpoints Signed-off-by: Dhiraj Bokde --- cmd/proxy.go | 86 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 54 insertions(+), 32 deletions(-) diff --git a/cmd/proxy.go b/cmd/proxy.go index 6b49ecbd..c8c96094 100644 --- a/cmd/proxy.go +++ b/cmd/proxy.go @@ -32,40 +32,53 @@ func runProxyServer(cmd *cobra.Command, args []string) error { ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort) - glog.Infof("connecting to MLMD server %s..", mlmdAddr) - conn, err := grpc.DialContext( // nolint:staticcheck - ctxTimeout, - mlmdAddr, - grpc.WithReturnConnectionError(), // nolint:staticcheck - grpc.WithBlock(), // nolint:staticcheck - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - return fmt.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err) + routers := make([]openapi.Router, 0) + + disableService := proxyCfg.DisableService + if disableService != CatalogService { + + // TODO read yaml catalog file and instantiate ModelCatalogAPI implementations + ModelCatalogServiceAPIService := openapi.NewModelCatalogServiceAPIService(map[string]openapi.ModelCatalogApi{}) + ModelCatalogServiceAPIController := openapi.NewModelCatalogServiceAPIController(ModelCatalogServiceAPIService) + routers = append(routers, ModelCatalogServiceAPIController) + + } else if disableService != RegistryService { + + mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort) + glog.Infof("connecting to MLMD server %s..", mlmdAddr) + conn, err := grpc.DialContext( // nolint:staticcheck + ctxTimeout, + mlmdAddr, + grpc.WithReturnConnectionError(), // nolint:staticcheck + grpc.WithBlock(), // nolint:staticcheck + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return fmt.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err) + } + defer conn.Close() + glog.Infof("connected to MLMD server") + + mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults() + _, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig) + if err != nil { + return fmt.Errorf("error creating MLMD types: %v", err) + } + service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig) + if err != nil { + return fmt.Errorf("error creating core service: %v", err) + } + + // TODO make registry API optional to support standalone Catalog deployments + ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service) + ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService) + routers = append(routers, ModelRegistryServiceAPIController) + + } else if disableService != "" { + return fmt.Errorf("invalid disable-service: %v", disableService) } - defer conn.Close() - glog.Infof("connected to MLMD server") - mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults() - _, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig) - if err != nil { - return fmt.Errorf("error creating MLMD types: %v", err) - } - service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig) - if err != nil { - return fmt.Errorf("error creating core service: %v", err) - } - - // TODO read yaml catalog file and instantiate ModelCatalogAPI implementations - ModelCatalogServiceAPIService := openapi.NewModelCatalogServiceAPIService(map[string]openapi.ModelCatalogApi{}) - ModelCatalogServiceAPIController := openapi.NewModelCatalogServiceAPIController(ModelCatalogServiceAPIService) - - // TODO make registry API optional to support standalone Catalog deployments - ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service) - ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService) - - router := openapi.NewRouter(ModelRegistryServiceAPIController, ModelCatalogServiceAPIController) + router := openapi.NewRouter(routers...) glog.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)) return nil @@ -79,11 +92,20 @@ func init() { proxyCmd.Flags().StringVar(&proxyCfg.MLMDHostname, "mlmd-hostname", proxyCfg.MLMDHostname, "MLMD hostname") proxyCmd.Flags().IntVar(&proxyCfg.MLMDPort, "mlmd-port", proxyCfg.MLMDPort, "MLMD port") + + proxyCmd.Flags().StringVar(&proxyCfg.DisableService, "disable-service", proxyCfg.DisableService, "Name of service/endpoint to disable, can be either \"catalog\" or \"registry\"") } +const ( + CatalogService string = "catalog" + RegistryService string = "registry" +) + type ProxyConfig struct { MLMDHostname string MLMDPort int + + DisableService string } var proxyCfg = ProxyConfig{