From 5994df69752d8a69a131a7218aa0d65328cd58f0 Mon Sep 17 00:00:00 2001 From: Kaia Lang Date: Mon, 12 May 2025 10:20:04 -0700 Subject: [PATCH] feat(core): define module registry and lifecycle hooks --- router/core/hooks.go | 245 +++++++++++++++++++++++++++ router/core/modules_v1.go | 135 +++++++++++++++ router/core/modules_v1_test.go | 230 +++++++++++++++++++++++++ router/internal/utils/ordered_set.go | 61 +++++++ router/internal/utils/ptrs.go | 6 + 5 files changed, 677 insertions(+) create mode 100644 router/core/hooks.go create mode 100644 router/core/modules_v1.go create mode 100644 router/core/modules_v1_test.go create mode 100644 router/internal/utils/ordered_set.go create mode 100644 router/internal/utils/ptrs.go diff --git a/router/core/hooks.go b/router/core/hooks.go new file mode 100644 index 0000000000..d388832a04 --- /dev/null +++ b/router/core/hooks.go @@ -0,0 +1,245 @@ +package core + +import ( + "context" + + "github.com/wundergraph/cosmo/router/internal/utils" +) + +// Application Lifecycle Hooks +type ApplicationLifecycleHook interface { + ApplicationStartHook + ApplicationStopHook +} + +type ApplicationStartHook interface { + OnApplicationStart(ctx context.Context) +} + +type ApplicationStopHook interface { + OnApplicationStop(ctx context.Context) +} + +// GraphQL Server Lifecycle Hooks +type GraphQLServerLifecycleHook interface { + GraphQLServerStartHook + GraphQLServerStopHook +} + +type GraphQLServerStartHook interface { + OnGraphQLServerStart(ctx context.Context) +} + +type GraphQLServerStopHook interface { + OnGraphQLServerStop(ctx context.Context) +} + +// Router Lifecycle Hooks +type RouterRequestHook interface { + OnRouterRequest(ctx context.Context) +} + +type RouterResponseHook interface { + OnRouterResponse(ctx context.Context) +} + +type RouterLifecycleHook interface { + RouterRequestHook + RouterResponseHook +} + +// Subgraph Lifecycle Hooks +type SubgraphRequestHook interface { + OnSubgraphRequest(ctx context.Context) +} + +type SubgraphResponseHook interface { + OnSubgraphResponse(ctx context.Context) +} + +type SubgraphLifecycleHook interface { + SubgraphRequestHook + SubgraphResponseHook +} + +// Operation Lifecycle Hooks +type OperationLifecycleHook interface { + OperationParseLifecycleHook + OperationNormalizeLifecycleHook + OperationValidateLifecycleHook + OperationPlanLifecycleHook + OperationExecuteLifecycleHook +} + +type OperationParseLifecycleHook interface { + OperationPreParseHook + OperationPostParseHook +} + +type OperationPreParseHook interface { + OnOperationPreParse(ctx context.Context) +} + +type OperationPostParseHook interface { + OnOperationPostParse(ctx context.Context) +} + +type OperationNormalizeLifecycleHook interface { + OperationPreNormalizeHook + OperationPostNormalizeHook +} + +type OperationPreNormalizeHook interface { + OnOperationPreNormalize(ctx context.Context) +} + +type OperationPostNormalizeHook interface { + OnOperationPostNormalize(ctx context.Context) +} + +type OperationValidateLifecycleHook interface { + OperationPreValidateHook + OperationPostValidateHook +} + +type OperationPreValidateHook interface { + OnOperationPreValidate(ctx context.Context) +} + +type OperationPostValidateHook interface { + OnOperationPostValidate(ctx context.Context) +} + +type OperationPlanLifecycleHook interface { + OperationPrePlanHook + OperationPostPlanHook +} + +type OperationPrePlanHook interface { + OnOperationPrePlan(ctx context.Context) +} + +type OperationPostPlanHook interface { + OnOperationPostPlan(ctx context.Context) +} + +type OperationExecuteLifecycleHook interface { + OperationPreExecuteHook + OperationPostExecuteHook +} + +type OperationPreExecuteHook interface { + OnOperationPreExecute(ctx context.Context) +} + +type OperationPostExecuteHook interface { + OnOperationPostExecute(ctx context.Context) +} + +// hookRegistry holds the list of hooks for each type. +type hookRegistry struct { + applicationStartHooks *utils.OrderedSet[ApplicationStartHook] + applicationStopHooks *utils.OrderedSet[ApplicationStopHook] + + graphQLServerStartHooks *utils.OrderedSet[GraphQLServerStartHook] + graphQLServerStopHooks *utils.OrderedSet[GraphQLServerStopHook] + + routerRequestHooks *utils.OrderedSet[RouterRequestHook] + routerResponseHooks *utils.OrderedSet[RouterResponseHook] + + subgraphRequestHooks *utils.OrderedSet[SubgraphRequestHook] + subgraphResponseHooks *utils.OrderedSet[SubgraphResponseHook] + + operationPreParseHooks *utils.OrderedSet[OperationPreParseHook] + operationPostParseHooks *utils.OrderedSet[OperationPostParseHook] + + operationPreNormalizeHooks *utils.OrderedSet[OperationPreNormalizeHook] + operationPostNormalizeHooks *utils.OrderedSet[OperationPostNormalizeHook] + + operationPreValidateHooks *utils.OrderedSet[OperationPreValidateHook] + operationPostValidateHooks *utils.OrderedSet[OperationPostValidateHook] + + operationPrePlanHooks *utils.OrderedSet[OperationPrePlanHook] + operationPostPlanHooks *utils.OrderedSet[OperationPostPlanHook] + + operationPreExecuteHooks *utils.OrderedSet[OperationPreExecuteHook] + operationPostExecuteHooks *utils.OrderedSet[OperationPostExecuteHook] +} + +// newHookRegistry initializes with empty sets. +func newHookRegistry() *hookRegistry { + return &hookRegistry{ + applicationStartHooks: utils.NewOrderedSet[ApplicationStartHook](), + applicationStopHooks: utils.NewOrderedSet[ApplicationStopHook](), + + graphQLServerStartHooks: utils.NewOrderedSet[GraphQLServerStartHook](), + graphQLServerStopHooks: utils.NewOrderedSet[GraphQLServerStopHook](), + + routerRequestHooks: utils.NewOrderedSet[RouterRequestHook](), + routerResponseHooks: utils.NewOrderedSet[RouterResponseHook](), + + subgraphRequestHooks: utils.NewOrderedSet[SubgraphRequestHook](), + subgraphResponseHooks: utils.NewOrderedSet[SubgraphResponseHook](), + + operationPreParseHooks: utils.NewOrderedSet[OperationPreParseHook](), + operationPostParseHooks: utils.NewOrderedSet[OperationPostParseHook](), + + operationPreNormalizeHooks: utils.NewOrderedSet[OperationPreNormalizeHook](), + operationPostNormalizeHooks: utils.NewOrderedSet[OperationPostNormalizeHook](), + + operationPreValidateHooks: utils.NewOrderedSet[OperationPreValidateHook](), + operationPostValidateHooks: utils.NewOrderedSet[OperationPostValidateHook](), + + operationPrePlanHooks: utils.NewOrderedSet[OperationPrePlanHook](), + operationPostPlanHooks: utils.NewOrderedSet[OperationPostPlanHook](), + + operationPreExecuteHooks: utils.NewOrderedSet[OperationPreExecuteHook](), + operationPostExecuteHooks: utils.NewOrderedSet[OperationPostExecuteHook](), + } +} + +// registerHook is a helper to add any hook type if implemented. +func registerHook[H comparable](inst any, set *utils.OrderedSet[H]) { + if h, ok := inst.(H); ok { + set.Add(h) + } +} + +// AddApplicationLifecycle wires up start/stop hooks. +func (hr *hookRegistry) AddApplicationLifecycle(inst any) { + registerHook(inst, hr.applicationStartHooks) + registerHook(inst, hr.applicationStopHooks) +} + +// AddGraphQLServerLifecycle wires up GraphQL server start/stop hooks. +func (hr *hookRegistry) AddGraphQLServerLifecycle(inst any) { + registerHook(inst, hr.graphQLServerStartHooks) + registerHook(inst, hr.graphQLServerStopHooks) +} + +// AddRouterLifecycle wires up router request/response hooks. +func (hr *hookRegistry) AddRouterLifecycle(inst any) { + registerHook(inst, hr.routerRequestHooks) + registerHook(inst, hr.routerResponseHooks) +} + +// AddSubgraphLifecycle wires up subgraph request/response hooks. +func (hr *hookRegistry) AddSubgraphLifecycle(inst any) { + registerHook(inst, hr.subgraphRequestHooks) + registerHook(inst, hr.subgraphResponseHooks) +} + +// AddOperationLifecycle wires up all operation lifecycle hooks. +func (hr *hookRegistry) AddOperationLifecycle(inst any) { + registerHook(inst, hr.operationPreParseHooks) + registerHook(inst, hr.operationPostParseHooks) + registerHook(inst, hr.operationPreNormalizeHooks) + registerHook(inst, hr.operationPostNormalizeHooks) + registerHook(inst, hr.operationPreValidateHooks) + registerHook(inst, hr.operationPostValidateHooks) + registerHook(inst, hr.operationPrePlanHooks) + registerHook(inst, hr.operationPostPlanHooks) + registerHook(inst, hr.operationPreExecuteHooks) + registerHook(inst, hr.operationPostExecuteHooks) +} + diff --git a/router/core/modules_v1.go b/router/core/modules_v1.go new file mode 100644 index 0000000000..dd6ee7f6cb --- /dev/null +++ b/router/core/modules_v1.go @@ -0,0 +1,135 @@ +package core + +import ( + "fmt" + "math" + "sort" + "sync" + "context" + "go.uber.org/zap" +) + +type moduleRegistry struct { + mu sync.RWMutex + modules map[string]MyModuleInfo +} +// NewModuleRegistry returns an empty, thread-safe module registry. +// Call this in tests (and anywhere you need isolation) instead of using the global. +func newModuleRegistry() *moduleRegistry { + return &moduleRegistry{ + modules: make(map[string]MyModuleInfo), + } +} + +// TODO: @kaialang discuss if we should push for dependency injection. +// defaultModuleRegistry is the package-level registry used by RegisterMyModule. +// For unit tests you should use newModuleRegistry() to get a fresh instance and avoid shared state. +var defaultModuleRegistry = newModuleRegistry() + + +type MyModuleInfo struct { + // ID is the unique identifier for a module, it must be unique across all modules. + ID string + // Priority decideds the order of execution of the module. + // The smaller the number, the higher the priority, the earlier the module is executed. + // For example, a priority of 0 is the highest priority. + // Modules with the same priority are executed in the order they are registered. + // If Priority is nil, the module is considered to have the lowest priority. + Priority *int + // New creates a new instance of the module. + New func() MyModule +} + +type MyModule interface { + MyModule() MyModuleInfo +} + +// RegisterMyModule registers a new MyModule instance. +// The registration order matters. Modules with the same priority +// are executed in the order they are registered. +// It panics if the module is already registered. +func RegisterMyModule(instance MyModule) { + defaultModuleRegistry.registerMyModule(instance) +} + +func (r *moduleRegistry) registerMyModule(instance MyModule) { + m := instance.MyModule() + + if m.ID == "" { + panic("MyModule.ID is required") + } + if val := m.New(); val == nil { + panic("MyModuleInfo.New must return a non-nil module instance") + } + + r.mu.Lock() + defer r.mu.Unlock() + + if _, ok := r.modules[m.ID]; ok { + panic(fmt.Sprintf("MyModule already registered: %s", m.ID)) + } + r.modules[m.ID] = m +} + +// sortMyModules sorts the modules by priority, 0 is the highest priority, is the first to be executed. +// If two modules have the same priority, they are sorted by registration order. +// If a module has no priority, it is considered to have the lowest priority. +func sortMyModules(modules []MyModuleInfo) []MyModuleInfo { + sort.Slice(modules, func(i, j int) bool { + var priorityI, priorityJ int = math.MaxInt, math.MaxInt + if modules[i].Priority != nil { + priorityI = *modules[i].Priority + } + if modules[j].Priority != nil { + priorityJ = *modules[j].Priority + } + + return priorityI < priorityJ + }) + return modules +} + +// getMyModules returns all registered modules sorted by priority +func (r *moduleRegistry) getMyModules() []MyModuleInfo { + r.mu.RLock() + defer r.mu.RUnlock() + + modules := make([]MyModuleInfo, 0, len(r.modules)) + for _, m := range r.modules { + modules = append(modules, m) + } + return sortMyModules(modules) +} + +// coreModules manages module initialization and hook registration. +type coreModules struct { + hookRegistry *hookRegistry + logger *zap.Logger +} + +// newCoreModules initializes with an empty registry. +func newCoreModules(logger *zap.Logger) *coreModules { + return &coreModules{ + hookRegistry: newHookRegistry(), + logger: logger, + } +} + +// initMyModules instantiates each module, registers any implemented hooks, and saves the hook registry. +func (c *coreModules) initMyModules(ctx context.Context, modules []MyModuleInfo) error { + hookRegistry := newHookRegistry() + + for _, info := range modules { + moduleInstance := info.New() + + hookRegistry.AddApplicationLifecycle(moduleInstance) + hookRegistry.AddGraphQLServerLifecycle(moduleInstance) + hookRegistry.AddRouterLifecycle(moduleInstance) + hookRegistry.AddSubgraphLifecycle(moduleInstance) + hookRegistry.AddOperationLifecycle(moduleInstance) + } + + c.hookRegistry = hookRegistry + + return nil +} diff --git a/router/core/modules_v1_test.go b/router/core/modules_v1_test.go new file mode 100644 index 0000000000..56a1535aed --- /dev/null +++ b/router/core/modules_v1_test.go @@ -0,0 +1,230 @@ +package core + +import ( + "testing" + "context" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/utils" + + "go.uber.org/zap/zaptest" +) + + +type testModule1 struct {} + +func (m *testModule1) MyModule() MyModuleInfo { + return MyModuleInfo{ + ID: "testModule1", + Priority: utils.Ptr(0), + New: func() MyModule { + return &testModule1{} + }, + } +} + +func (m *testModule1) OnApplicationStart(ctx context.Context) {} + +type testModule2 struct {} + +func (m *testModule2) MyModule() MyModuleInfo { + return MyModuleInfo{ + ID: "testModule2", + Priority: utils.Ptr(1), + New: func() MyModule { + return &testModule2{} + }, + } +} + +func (m *testModule2) OnApplicationStart(ctx context.Context) {} +func (m *testModule2) OnApplicationStop(ctx context.Context) {} + +type testModule3 struct {} + +func (m *testModule3) MyModule() MyModuleInfo { + return MyModuleInfo{ + Priority: utils.Ptr(1), + New: func() MyModule { + return &testModule3{} + }, + } +} + +func (m *testModule3) OnApplicationStart(ctx context.Context) {} +func (m *testModule3) OnApplicationStop(ctx context.Context) {} + + +type testModule4 struct {} + +func (m *testModule4) MyModule() MyModuleInfo { + return MyModuleInfo{ + ID: "testModule4", + Priority: utils.Ptr(1), + } +} + +// interface guards +var _ ApplicationStartHook = (*testModule1)(nil) + +// registers the applicationStartHook only once +var _ ApplicationStartHook = (*testModule2)(nil) +var _ ApplicationLifecycleHook = (*testModule2)(nil) + +var _ ApplicationStartHook = (*testModule3)(nil) +var _ ApplicationStopHook = (*testModule3)(nil) + +func TestRegisterMyModule(t *testing.T) { + t.Parallel() + + m1 := &testModule1{} + m2 := &testModule2{} + m3 := &testModule3{} + m4 := &testModule4{} + m5 := &testModule1{} + t.Run("success", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + testModuleRegistry.registerMyModule(m1) + testModuleRegistry.registerMyModule(m2) + + require.Equal(t, "testModule1", testModuleRegistry.modules["testModule1"].ID) + require.Equal(t, "testModule2", testModuleRegistry.modules["testModule2"].ID) + }) + + t.Run("panic_if_module_id_is_empty", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + require.Panics(t, func() { + testModuleRegistry.registerMyModule(m3) + }) + }) + + t.Run("panic_if_module_new_returns_nil", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + require.Panics(t, func() { + testModuleRegistry.registerMyModule(m4) + }) + }) + + t.Run("panic_if_module_id_is_not_unique", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + require.Panics(t, func() { + testModuleRegistry.registerMyModule(m1) + testModuleRegistry.registerMyModule(m5) + }) + }) +} + +func TestSortMyModules(t *testing.T) { + t.Parallel() + + module0 := MyModuleInfo{ + ID: "module0", + Priority: utils.Ptr(0), + } + + module1 := MyModuleInfo{ + ID: "module1", + Priority: utils.Ptr(1), + } + + module2 := MyModuleInfo{ + ID: "module2", + Priority: utils.Ptr(2), + } + + module3 := MyModuleInfo{ + ID: "module3", + Priority: utils.Ptr(0), + } + + moduleNilPriority := MyModuleInfo{ + ID: "moduleNil", + } + + t.Run("success", func(t *testing.T) { + modules := []MyModuleInfo{ + moduleNilPriority, + module2, + module0, + module1, + } + result := sortMyModules(modules) + + expected := []MyModuleInfo{ + module0, + module1, + module2, + moduleNilPriority, + } + + require.EqualValues(t, expected, result) + }) + + t.Run("same_priority", func(t *testing.T) { + modules := []MyModuleInfo{ + module3, + module0, + } + result := sortMyModules(modules) + + expected := []MyModuleInfo{ + module3, + module0, + } + + require.EqualValues(t, expected, result) + }) + + t.Run("no_modules_not_panic", func(t *testing.T) { + modules := []MyModuleInfo{} + require.Equal(t, []MyModuleInfo{}, sortMyModules(modules)) + }) +} + +func TestInitMyModules(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + modules := []MyModuleInfo{ + { + ID: "testModule1", + New: func() MyModule { + return &testModule1{} + }, + }, + { + ID: "testModule3", + New: func() MyModule { + return &testModule3{} + }, + }, + } + cm := newCoreModules(zaptest.NewLogger(t)) + err := cm.initMyModules(context.Background(), modules) + require.NoError(t, err) + + require.Equal(t, 2, len(cm.hookRegistry.applicationStartHooks.Values())) + require.Equal(t, 1, len(cm.hookRegistry.applicationStopHooks.Values())) + }) + + t.Run("success_deduplicate_hooks", func(t *testing.T) { + modules := []MyModuleInfo{ + { + ID: "testModule2", + New: func() MyModule { + return &testModule2{} + }, + }, + } + cm := newCoreModules(zaptest.NewLogger(t)) + err := cm.initMyModules(context.Background(), modules) + require.NoError(t, err) + + require.Equal(t, 1, len(cm.hookRegistry.applicationStartHooks.Values())) + require.Equal(t, 1, len(cm.hookRegistry.applicationStopHooks.Values())) + }) +} diff --git a/router/internal/utils/ordered_set.go b/router/internal/utils/ordered_set.go new file mode 100644 index 0000000000..aab895b0b5 --- /dev/null +++ b/router/internal/utils/ordered_set.go @@ -0,0 +1,61 @@ +package utils + +type OrderedSet[T comparable] struct { + elements []T + index map[T]struct{} +} + +// NewOrderedSet creates and returns a new OrderedSet. +func NewOrderedSet[T comparable]() *OrderedSet[T] { + return &OrderedSet[T]{ + elements: make([]T, 0), + index: make(map[T]struct{}), + } +} + +// Add inserts elem into the set if it's not already present. +func (s *OrderedSet[T]) Add(elem T) { + if _, exists := s.index[elem]; !exists { + s.index[elem] = struct{}{} + s.elements = append(s.elements, elem) + } +} + +// Remove deletes elem from the set if it exists, preserving order of other elements. +func (s *OrderedSet[T]) Remove(elem T) { + if _, exists := s.index[elem]; exists { + delete(s.index, elem) + // rebuild slice without the removed element + for i, v := range s.elements { + if v == elem { + s.elements = append(s.elements[:i], s.elements[i+1:]...) + break + } + } + } +} + +// Contains returns true if elem is in the set. +func (s *OrderedSet[T]) Contains(elem T) bool { + _, exists := s.index[elem] + return exists +} + +// Values returns a slice of elements in insertion order. +// The returned slice is a copy; modifying it won't affect the set. +func (s *OrderedSet[T]) Values() []T { + dup := make([]T, len(s.elements)) + copy(dup, s.elements) + return dup +} + +// Len returns the number of elements in the set. +func (s *OrderedSet[T]) Len() int { + return len(s.elements) +} + +// Clear removes all elements from the set. +func (s *OrderedSet[T]) Clear() { + s.elements = make([]T, 0) + s.index = make(map[T]struct{}) +} \ No newline at end of file diff --git a/router/internal/utils/ptrs.go b/router/internal/utils/ptrs.go new file mode 100644 index 0000000000..c954c0b883 --- /dev/null +++ b/router/internal/utils/ptrs.go @@ -0,0 +1,6 @@ +package utils + +func Ptr[T any](v T) *T { + return &v +} +