From d4204d5c3e5f39811963378be691cbdc576d267a Mon Sep 17 00:00:00 2001 From: Theaux Masquelier <43664045+Theauxm@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:48:30 -0600 Subject: [PATCH 1/2] feat: add per-train lifecycle hook overrides on ServiceTrain Adds protected virtual OnStarted, OnCompleted, OnFailed, and OnCancelled methods to ServiceTrain. Individual trains can override these to react to their own lifecycle events without registering a global ITrainLifecycleHook. Global hooks fire first, then train-level overrides. Exceptions in overrides are caught and logged, matching the global hook error handling pattern. --- .../Services/ServiceTrain/ServiceTrain.cs | 117 +++++ .../TrainLifecycleOverrideTests.cs | 442 ++++++++++++++++++ 2 files changed, 559 insertions(+) create mode 100644 tests/Trax.Effect.Tests.Data.InMemory.Integration/IntegrationTests/TrainLifecycleOverrideTests.cs diff --git a/src/Trax.Effect/Services/ServiceTrain/ServiceTrain.cs b/src/Trax.Effect/Services/ServiceTrain/ServiceTrain.cs index 9fc2afb..c5b163e 100644 --- a/src/Trax.Effect/Services/ServiceTrain/ServiceTrain.cs +++ b/src/Trax.Effect/Services/ServiceTrain/ServiceTrain.cs @@ -85,6 +85,37 @@ public abstract class ServiceTrain : Train, IServiceTrain< ?? GetType().FullName ?? throw new TrainException($"Could not find FullName for ({GetType().Name})"); + /// + /// Called after the train's metadata is initialized and persisted, before RunInternal executes. + /// Override to add per-train startup logic. Exceptions are caught and logged — they will not + /// prevent the train from running. + /// + protected virtual Task OnStarted(Metadata metadata, CancellationToken ct) => Task.CompletedTask; + + /// + /// Called after a successful run, after output is persisted and global hooks have fired. + /// Override to add per-train completion logic (e.g., notifications, cache invalidation). + /// Exceptions are caught and logged — they will not cause the train to report failure. + /// + protected virtual Task OnCompleted(Metadata metadata, CancellationToken ct) => + Task.CompletedTask; + + /// + /// Called after a failed run (non-cancellation exception), after failure state is persisted + /// and global hooks have fired. Override to add per-train failure handling (e.g., alerting). + /// Exceptions are caught and logged — they will not mask the original failure. + /// + protected virtual Task OnFailed(Metadata metadata, Exception exception, CancellationToken ct) => + Task.CompletedTask; + + /// + /// Called after cancellation (OperationCanceledException), after cancellation state is persisted + /// and global hooks have fired. Override to add per-train cancellation handling. + /// Exceptions are caught and logged. + /// + protected virtual Task OnCancelled(Metadata metadata, CancellationToken ct) => + Task.CompletedTask; + /// /// Overrides the base Train Run method to add database tracking and logging capabilities. /// @@ -108,6 +139,19 @@ public override async Task Run(TIn input, CancellationToken cancellationTo await LifecycleHookRunner.OnStarted(Metadata, CancellationToken); + try + { + await OnStarted(Metadata, CancellationToken); + } + catch (Exception hookEx) + { + Logger?.LogError( + hookEx, + "Train-level OnStarted hook threw for train ({TrainName}).", + TrainName + ); + } + try { Logger?.LogTrace("Running Train: ({TrainName})", TrainName); @@ -126,10 +170,40 @@ public override async Task Run(TIn input, CancellationToken cancellationTo await EffectRunner.SaveChanges(CancellationToken); if (exception is OperationCanceledException) + { await LifecycleHookRunner.OnCancelled(Metadata, CancellationToken); + + try + { + await OnCancelled(Metadata, CancellationToken); + } + catch (Exception hookEx) + { + Logger?.LogError( + hookEx, + "Train-level OnCancelled hook threw for train ({TrainName}).", + TrainName + ); + } + } else + { await LifecycleHookRunner.OnFailed(Metadata, exception, CancellationToken); + try + { + await OnFailed(Metadata, exception, CancellationToken); + } + catch (Exception hookEx) + { + Logger?.LogError( + hookEx, + "Train-level OnFailed hook threw for train ({TrainName}).", + TrainName + ); + } + } + exception.Rethrow(); } @@ -143,6 +217,19 @@ public override async Task Run(TIn input, CancellationToken cancellationTo await LifecycleHookRunner.OnCompleted(Metadata, CancellationToken); + try + { + await OnCompleted(Metadata, CancellationToken); + } + catch (Exception hookEx) + { + Logger?.LogError( + hookEx, + "Train-level OnCompleted hook threw for train ({TrainName}).", + TrainName + ); + } + return output; } catch (Exception e) @@ -157,10 +244,40 @@ public override async Task Run(TIn input, CancellationToken cancellationTo await EffectRunner.SaveChanges(CancellationToken); if (e is OperationCanceledException) + { await LifecycleHookRunner.OnCancelled(Metadata, CancellationToken); + + try + { + await OnCancelled(Metadata, CancellationToken); + } + catch (Exception hookEx) + { + Logger?.LogError( + hookEx, + "Train-level OnCancelled hook threw for train ({TrainName}).", + TrainName + ); + } + } else + { await LifecycleHookRunner.OnFailed(Metadata, e, CancellationToken); + try + { + await OnFailed(Metadata, e, CancellationToken); + } + catch (Exception hookEx) + { + Logger?.LogError( + hookEx, + "Train-level OnFailed hook threw for train ({TrainName}).", + TrainName + ); + } + } + throw; } } diff --git a/tests/Trax.Effect.Tests.Data.InMemory.Integration/IntegrationTests/TrainLifecycleOverrideTests.cs b/tests/Trax.Effect.Tests.Data.InMemory.Integration/IntegrationTests/TrainLifecycleOverrideTests.cs new file mode 100644 index 0000000..74ece47 --- /dev/null +++ b/tests/Trax.Effect.Tests.Data.InMemory.Integration/IntegrationTests/TrainLifecycleOverrideTests.cs @@ -0,0 +1,442 @@ +using FluentAssertions; +using LanguageExt; +using Microsoft.Extensions.DependencyInjection; +using Trax.Core.Exceptions; +using Trax.Effect.Extensions; +using Trax.Effect.Models.Metadata; +using Trax.Effect.Services.ServiceTrain; +using Trax.Effect.Tests.Data.InMemory.Integration.Fixtures; + +namespace Trax.Effect.Tests.Data.InMemory.Integration.IntegrationTests; + +public class TrainLifecycleOverrideTests : TestSetup +{ + public override ServiceProvider ConfigureServices(IServiceCollection services) => + services + .AddScopedTraxRoute() + .AddScopedTraxRoute() + .AddScopedTraxRoute() + .AddScopedTraxRoute() + .AddScopedTraxRoute() + .AddScopedTraxRoute() + .AddScopedTraxRoute() + .AddScopedTraxRoute() + .BuildServiceProvider(); + + #region OnStarted + + [Test] + public async Task Run_SuccessfulTrain_CallsOnStarted() + { + var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService(); + + await train.Run(Unit.Default); + + train.StartedCalled.Should().BeTrue(); + train.StartedMetadata.Should().NotBeNull(); + train.StartedMetadata!.Name.Should().Be(typeof(IRecordingTrain).FullName); + } + + [Test] + public async Task Run_FailingTrain_StillCallsOnStarted() + { + var train = (FailingRecordingTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + await act.Should().ThrowAsync(); + + train.StartedCalled.Should().BeTrue(); + } + + [Test] + public async Task Run_OnStartedThrows_TrainStillRuns() + { + var train = (ThrowingHookTrain) + Scope.ServiceProvider.GetRequiredService(); + + var result = await train.Run(Unit.Default); + + result.Should().Be(Unit.Default); + train.Metadata!.TrainState.Should().Be(Enums.TrainState.Completed); + } + + #endregion + + #region OnCompleted + + [Test] + public async Task Run_SuccessfulTrain_CallsOnCompleted() + { + var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService(); + + await train.Run(Unit.Default); + + train.CompletedCalled.Should().BeTrue(); + train.CompletedMetadata.Should().NotBeNull(); + } + + [Test] + public async Task Run_FailingTrain_DoesNotCallOnCompleted() + { + var train = (FailingRecordingTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + await act.Should().ThrowAsync(); + + train.CompletedCalled.Should().BeFalse(); + } + + [Test] + public async Task Run_OnCompletedThrows_DoesNotCauseFailure() + { + var train = (ThrowingHookTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + + await act.Should().NotThrowAsync(); + } + + #endregion + + #region OnFailed + + [Test] + public async Task Run_FailingTrain_CallsOnFailedWithException() + { + var train = (FailingRecordingTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + await act.Should().ThrowAsync(); + + train.FailedCalled.Should().BeTrue(); + train.FailedException.Should().NotBeNull(); + train.FailedException.Should().BeOfType(); + train.FailedMetadata.Should().NotBeNull(); + } + + [Test] + public async Task Run_SuccessfulTrain_DoesNotCallOnFailed() + { + var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService(); + + await train.Run(Unit.Default); + + train.FailedCalled.Should().BeFalse(); + } + + [Test] + public async Task Run_OnFailedThrows_OriginalExceptionStillPropagates() + { + var train = (ThrowingOnFailedHookTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + + await act.Should().ThrowAsync().WithMessage("*Intentional train failure*"); + } + + #endregion + + #region OnCancelled + + [Test] + public async Task Run_CancelledTrain_CallsOnCancelled() + { + var train = (CancellingRecordingTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + await act.Should().ThrowAsync(); + + train.CancelledCalled.Should().BeTrue(); + train.CancelledMetadata.Should().NotBeNull(); + } + + [Test] + public async Task Run_FailingTrain_NonCancellation_DoesNotCallOnCancelled() + { + var train = (FailingRecordingTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + await act.Should().ThrowAsync(); + + train.CancelledCalled.Should().BeFalse(); + } + + [Test] + public async Task Run_OnCancelledThrows_CancellationStillPropagates() + { + var train = (ThrowingOnCancelledHookTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + + await act.Should().ThrowAsync(); + } + + #endregion + + #region Ordering & Defaults + + [Test] + public async Task Run_SuccessfulTrain_OnStartedBeforeOnCompleted() + { + var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService(); + + await train.Run(Unit.Default); + + train.CallOrder.Should().ContainInOrder("OnStarted", "OnCompleted"); + } + + [Test] + public async Task Run_FailingTrain_OnStartedBeforeOnFailed() + { + var train = (FailingRecordingTrain) + Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + await act.Should().ThrowAsync(); + + train.CallOrder.Should().ContainInOrder("OnStarted", "OnFailed"); + } + + [Test] + public async Task Run_NoOverrides_DefaultsDoNotThrow() + { + var train = Scope.ServiceProvider.GetRequiredService(); + + var act = async () => await train.Run(Unit.Default); + + await act.Should().NotThrowAsync(); + } + + [Test] + public async Task Run_PartialOverride_OnlyOverriddenMethodCalled() + { + var train = (PartialOverrideTrain) + Scope.ServiceProvider.GetRequiredService(); + + await train.Run(Unit.Default); + + train.CompletedCalled.Should().BeTrue(); + train.CallOrder.Should().ContainSingle().Which.Should().Be("OnCompleted"); + } + + [Test] + public async Task Run_CancellationTokenPassedToHooks() + { + var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService(); + using var cts = new CancellationTokenSource(); + + await train.Run(Unit.Default, cts.Token); + + train.StartedCancellationToken.Should().Be(cts.Token); + train.CompletedCancellationToken.Should().Be(cts.Token); + } + + #endregion + + #region Test Trains + + private interface IRecordingTrain : IServiceTrain { } + + private class RecordingTrain : ServiceTrain, IRecordingTrain + { + public bool StartedCalled { get; private set; } + public bool CompletedCalled { get; private set; } + public bool FailedCalled { get; private set; } + public bool CancelledCalled { get; private set; } + public Metadata? StartedMetadata { get; private set; } + public Metadata? CompletedMetadata { get; private set; } + public Metadata? FailedMetadata { get; private set; } + public Exception? FailedException { get; private set; } + public Metadata? CancelledMetadata { get; private set; } + public CancellationToken StartedCancellationToken { get; private set; } + public CancellationToken CompletedCancellationToken { get; private set; } + public List CallOrder { get; } = []; + + protected override async Task> RunInternal(Unit input) => + Activate(input).Resolve(); + + protected override Task OnStarted(Metadata metadata, CancellationToken ct) + { + StartedCalled = true; + StartedMetadata = metadata; + StartedCancellationToken = ct; + CallOrder.Add("OnStarted"); + return Task.CompletedTask; + } + + protected override Task OnCompleted(Metadata metadata, CancellationToken ct) + { + CompletedCalled = true; + CompletedMetadata = metadata; + CompletedCancellationToken = ct; + CallOrder.Add("OnCompleted"); + return Task.CompletedTask; + } + + protected override Task OnFailed( + Metadata metadata, + Exception exception, + CancellationToken ct + ) + { + FailedCalled = true; + FailedMetadata = metadata; + FailedException = exception; + CallOrder.Add("OnFailed"); + return Task.CompletedTask; + } + + protected override Task OnCancelled(Metadata metadata, CancellationToken ct) + { + CancelledCalled = true; + CancelledMetadata = metadata; + CallOrder.Add("OnCancelled"); + return Task.CompletedTask; + } + } + + private interface IFailingRecordingTrain : IServiceTrain { } + + private class FailingRecordingTrain : ServiceTrain, IFailingRecordingTrain + { + public bool StartedCalled { get; private set; } + public bool CompletedCalled { get; private set; } + public bool FailedCalled { get; private set; } + public bool CancelledCalled { get; private set; } + public Metadata? FailedMetadata { get; private set; } + public Exception? FailedException { get; private set; } + public List CallOrder { get; } = []; + + protected override async Task> RunInternal(Unit input) => + new TrainException("Intentional train failure"); + + protected override Task OnStarted(Metadata metadata, CancellationToken ct) + { + StartedCalled = true; + CallOrder.Add("OnStarted"); + return Task.CompletedTask; + } + + protected override Task OnCompleted(Metadata metadata, CancellationToken ct) + { + CompletedCalled = true; + CallOrder.Add("OnCompleted"); + return Task.CompletedTask; + } + + protected override Task OnFailed( + Metadata metadata, + Exception exception, + CancellationToken ct + ) + { + FailedCalled = true; + FailedMetadata = metadata; + FailedException = exception; + CallOrder.Add("OnFailed"); + return Task.CompletedTask; + } + + protected override Task OnCancelled(Metadata metadata, CancellationToken ct) + { + CancelledCalled = true; + CallOrder.Add("OnCancelled"); + return Task.CompletedTask; + } + } + + private interface ICancellingRecordingTrain : IServiceTrain { } + + private class CancellingRecordingTrain : ServiceTrain, ICancellingRecordingTrain + { + public bool CancelledCalled { get; private set; } + public Metadata? CancelledMetadata { get; private set; } + + protected override async Task> RunInternal(Unit input) => + throw new OperationCanceledException("Intentional cancellation"); + + protected override Task OnCancelled(Metadata metadata, CancellationToken ct) + { + CancelledCalled = true; + CancelledMetadata = metadata; + return Task.CompletedTask; + } + } + + private interface IThrowingHookTrain : IServiceTrain { } + + private class ThrowingHookTrain : ServiceTrain, IThrowingHookTrain + { + protected override async Task> RunInternal(Unit input) => + Activate(input).Resolve(); + + protected override Task OnStarted(Metadata metadata, CancellationToken ct) => + throw new InvalidOperationException("OnStarted hook failed"); + + protected override Task OnCompleted(Metadata metadata, CancellationToken ct) => + throw new InvalidOperationException("OnCompleted hook failed"); + } + + private interface IThrowingOnFailedHookTrain : IServiceTrain { } + + private class ThrowingOnFailedHookTrain : ServiceTrain, IThrowingOnFailedHookTrain + { + protected override async Task> RunInternal(Unit input) => + new TrainException("Intentional train failure"); + + protected override Task OnFailed( + Metadata metadata, + Exception exception, + CancellationToken ct + ) => throw new InvalidOperationException("OnFailed hook failed"); + } + + private interface IThrowingOnCancelledHookTrain : IServiceTrain { } + + private class ThrowingOnCancelledHookTrain + : ServiceTrain, + IThrowingOnCancelledHookTrain + { + protected override async Task> RunInternal(Unit input) => + throw new OperationCanceledException("Intentional cancellation"); + + protected override Task OnCancelled(Metadata metadata, CancellationToken ct) => + throw new InvalidOperationException("OnCancelled hook failed"); + } + + private interface IPartialOverrideTrain : IServiceTrain { } + + private class PartialOverrideTrain : ServiceTrain, IPartialOverrideTrain + { + public bool CompletedCalled { get; private set; } + public List CallOrder { get; } = []; + + protected override async Task> RunInternal(Unit input) => + Activate(input).Resolve(); + + protected override Task OnCompleted(Metadata metadata, CancellationToken ct) + { + CompletedCalled = true; + CallOrder.Add("OnCompleted"); + return Task.CompletedTask; + } + } + + private interface INoOverrideTrain : IServiceTrain { } + + private class NoOverrideTrain : ServiceTrain, INoOverrideTrain + { + protected override async Task> RunInternal(Unit input) => + Activate(input).Resolve(); + } + + #endregion +} From 2e173533abb3ea55c603e9fc37e5796624fb6db7 Mon Sep 17 00:00:00 2001 From: Theaux Masquelier <43664045+Theauxm@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:50:12 -0600 Subject: [PATCH 2/2] refactor: unify lifecycle hook registration with generic LifecycleHookFactory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AddLifecycleHook() now accepts either ITrainLifecycleHook or ITrainLifecycleHookFactory directly. When given a hook type, a generic LifecycleHookFactory is created internally — no need for hand-written factory classes. Removes BroadcastLifecycleHookFactory in favor of the generic factory. --- .../Extensions/BroadcasterExtensions.cs | 10 +- .../Extensions/ServiceExtensions.cs | 77 +++++++--- .../BroadcastLifecycleHookFactory.cs | 15 -- .../LifecycleHookFactory.cs | 17 ++ .../UnitTests/BroadcasterRegistrationTests.cs | 2 +- .../Services/AddLifecycleHookTests.cs | 145 ++++++++++++++++++ 6 files changed, 219 insertions(+), 47 deletions(-) delete mode 100644 src/Trax.Effect/Services/TrainEventBroadcaster/BroadcastLifecycleHookFactory.cs create mode 100644 src/Trax.Effect/Services/TrainLifecycleHookFactory/LifecycleHookFactory.cs diff --git a/src/Trax.Effect/Extensions/BroadcasterExtensions.cs b/src/Trax.Effect/Extensions/BroadcasterExtensions.cs index 3e03b29..d0bd6bd 100644 --- a/src/Trax.Effect/Extensions/BroadcasterExtensions.cs +++ b/src/Trax.Effect/Extensions/BroadcasterExtensions.cs @@ -2,7 +2,6 @@ using Trax.Effect.Configuration.BroadcasterBuilder; using Trax.Effect.Configuration.TraxEffectBuilder; using Trax.Effect.Services.TrainEventBroadcaster; -using Trax.Effect.Services.TrainLifecycleHookFactory; namespace Trax.Effect.Extensions; @@ -31,14 +30,7 @@ Action configure var broadcasterBuilder = new BroadcasterBuilder(builder); configure(broadcasterBuilder); - builder.ServiceCollection.AddTransient(); - builder - .ServiceCollection.AddSingleton() - .AddSingleton(sp => - sp.GetRequiredService() - ); - - builder.EffectRegistry?.Register(typeof(BroadcastLifecycleHookFactory), toggleable: false); + builder.AddLifecycleHook(toggleable: false); builder.ServiceCollection.AddHostedService(); diff --git a/src/Trax.Effect/Extensions/ServiceExtensions.cs b/src/Trax.Effect/Extensions/ServiceExtensions.cs index faf7b81..e9b4e20 100644 --- a/src/Trax.Effect/Extensions/ServiceExtensions.cs +++ b/src/Trax.Effect/Extensions/ServiceExtensions.cs @@ -10,6 +10,7 @@ using Trax.Effect.Services.JunctionEffectProviderFactory; using Trax.Effect.Services.JunctionEffectRunner; using Trax.Effect.Services.LifecycleHookRunner; +using Trax.Effect.Services.TrainLifecycleHook; using Trax.Effect.Services.TrainLifecycleHookFactory; namespace Trax.Effect.Extensions; @@ -392,55 +393,87 @@ public static TraxEffectBuilder AddJunctionEffect - /// Registers a train lifecycle hook with both its interface and implementation type, - /// using an existing factory instance. Lifecycle hooks run at train start/completion/failure boundaries. + /// Registers a train lifecycle hook or hook factory. + /// + /// If implements , a factory is + /// created internally — no need to write a separate factory class. The hook is resolved from DI + /// on each train execution, so constructor-injected dependencies work. + /// + /// + /// If implements , the factory + /// is registered directly (advanced usage). + /// /// - /// The lifecycle hook factory interface. - /// The concrete lifecycle hook factory type. + /// + /// Either a concrete type or an type. + /// /// The effect builder. - /// The factory instance to register. /// Whether this hook can be toggled on/off at runtime. Defaults to true. /// The effect builder for chaining. - public static TraxEffectBuilder AddLifecycleHook( + public static TraxEffectBuilder AddLifecycleHook( this TraxEffectBuilder builder, - TLifecycleHookFactory factory, bool toggleable = true ) - where TILifecycleHookFactory : class, ITrainLifecycleHookFactory - where TLifecycleHookFactory : class, TILifecycleHookFactory + where T : class { - builder - .ServiceCollection.AddSingleton(factory) - .AddSingleton(sp => - sp.GetRequiredService() - ) - .AddSingleton(sp => - sp.GetRequiredService() + if (typeof(ITrainLifecycleHookFactory).IsAssignableFrom(typeof(T))) + { + builder + .ServiceCollection.AddSingleton(typeof(T)) + .AddSingleton(sp => + (ITrainLifecycleHookFactory)sp.GetRequiredService(typeof(T)) + ); + + builder.EffectRegistry?.Register(typeof(T), toggleable: toggleable); + } + else if (typeof(ITrainLifecycleHook).IsAssignableFrom(typeof(T))) + { + builder.ServiceCollection.AddTransient(typeof(T)); + + var factoryType = typeof(LifecycleHookFactory<>).MakeGenericType(typeof(T)); + builder.ServiceCollection.AddSingleton(factoryType); + builder.ServiceCollection.AddSingleton(sp => + (ITrainLifecycleHookFactory)sp.GetRequiredService(factoryType) ); - builder.EffectRegistry?.Register(typeof(TLifecycleHookFactory), toggleable: toggleable); + builder.EffectRegistry?.Register(factoryType, toggleable: toggleable); + } + else + { + throw new InvalidOperationException( + $"AddLifecycleHook<{typeof(T).Name}>() requires a type that implements " + + $"ITrainLifecycleHook or ITrainLifecycleHookFactory." + ); + } return builder; } /// - /// Registers a train lifecycle hook resolved from DI. - /// Lifecycle hooks run at train start/completion/failure boundaries. + /// Registers a train lifecycle hook with both its interface and implementation type, + /// using an existing factory instance. Lifecycle hooks run at train start/completion/failure boundaries. /// + /// The lifecycle hook factory interface. /// The concrete lifecycle hook factory type. /// The effect builder. + /// The factory instance to register. /// Whether this hook can be toggled on/off at runtime. Defaults to true. /// The effect builder for chaining. - public static TraxEffectBuilder AddLifecycleHook( + public static TraxEffectBuilder AddLifecycleHook( this TraxEffectBuilder builder, + TLifecycleHookFactory factory, bool toggleable = true ) - where TLifecycleHookFactory : class, ITrainLifecycleHookFactory + where TILifecycleHookFactory : class, ITrainLifecycleHookFactory + where TLifecycleHookFactory : class, TILifecycleHookFactory { builder - .ServiceCollection.AddSingleton() + .ServiceCollection.AddSingleton(factory) .AddSingleton(sp => sp.GetRequiredService() + ) + .AddSingleton(sp => + sp.GetRequiredService() ); builder.EffectRegistry?.Register(typeof(TLifecycleHookFactory), toggleable: toggleable); diff --git a/src/Trax.Effect/Services/TrainEventBroadcaster/BroadcastLifecycleHookFactory.cs b/src/Trax.Effect/Services/TrainEventBroadcaster/BroadcastLifecycleHookFactory.cs deleted file mode 100644 index 4edf0bb..0000000 --- a/src/Trax.Effect/Services/TrainEventBroadcaster/BroadcastLifecycleHookFactory.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Microsoft.Extensions.DependencyInjection; -using Trax.Effect.Services.TrainLifecycleHook; -using Trax.Effect.Services.TrainLifecycleHookFactory; - -namespace Trax.Effect.Services.TrainEventBroadcaster; - -/// -/// Factory that creates instances via DI. -/// -public class BroadcastLifecycleHookFactory(IServiceProvider serviceProvider) - : ITrainLifecycleHookFactory -{ - public ITrainLifecycleHook Create() => - serviceProvider.GetRequiredService(); -} diff --git a/src/Trax.Effect/Services/TrainLifecycleHookFactory/LifecycleHookFactory.cs b/src/Trax.Effect/Services/TrainLifecycleHookFactory/LifecycleHookFactory.cs new file mode 100644 index 0000000..80baf44 --- /dev/null +++ b/src/Trax.Effect/Services/TrainLifecycleHookFactory/LifecycleHookFactory.cs @@ -0,0 +1,17 @@ +using Microsoft.Extensions.DependencyInjection; +using Trax.Effect.Services.TrainLifecycleHook; + +namespace Trax.Effect.Services.TrainLifecycleHookFactory; + +/// +/// Generic factory that creates lifecycle hook instances via DI. +/// Used internally by the AddLifecycleHook<THook>() overload +/// so that users don't need to write their own factory classes. +/// +public class LifecycleHookFactory(IServiceProvider serviceProvider) + : ITrainLifecycleHookFactory + where THook : class, ITrainLifecycleHook +{ + public ITrainLifecycleHook Create() => + ActivatorUtilities.CreateInstance(serviceProvider); +} diff --git a/tests/Trax.Effect.Tests.Broadcaster/UnitTests/BroadcasterRegistrationTests.cs b/tests/Trax.Effect.Tests.Broadcaster/UnitTests/BroadcasterRegistrationTests.cs index 41a0467..8a72327 100644 --- a/tests/Trax.Effect.Tests.Broadcaster/UnitTests/BroadcasterRegistrationTests.cs +++ b/tests/Trax.Effect.Tests.Broadcaster/UnitTests/BroadcasterRegistrationTests.cs @@ -115,7 +115,7 @@ public void UseBroadcaster_RegistersHookAsNonToggleable() ); // The broadcast hook should be registered and enabled (non-toggleable) - registry.IsEnabled(typeof(BroadcastLifecycleHookFactory)).Should().BeTrue(); + registry.IsEnabled(typeof(LifecycleHookFactory)).Should().BeTrue(); } [Test] diff --git a/tests/Trax.Effect.Tests.Integration/UnitTests/Services/AddLifecycleHookTests.cs b/tests/Trax.Effect.Tests.Integration/UnitTests/Services/AddLifecycleHookTests.cs index 828753b..2bec357 100644 --- a/tests/Trax.Effect.Tests.Integration/UnitTests/Services/AddLifecycleHookTests.cs +++ b/tests/Trax.Effect.Tests.Integration/UnitTests/Services/AddLifecycleHookTests.cs @@ -149,10 +149,155 @@ public void AddLifecycleHook_Multiple_AllRegistered() #endregion + #region Direct Hook Registration + + [Test] + public void AddLifecycleHook_DirectHook_RegistersFactoryAsSingleton() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddTrax(trax => trax.AddEffects(effects => effects.AddLifecycleHook())); + using var provider = services.BuildServiceProvider(); + + var factories = provider.GetServices().ToList(); + + factories.Should().ContainSingle(f => f is LifecycleHookFactory); + } + + [Test] + public void AddLifecycleHook_DirectHook_RegistersInEffectRegistry() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddTrax(trax => trax.AddEffects(effects => effects.AddLifecycleHook())); + using var provider = services.BuildServiceProvider(); + + var registry = provider.GetRequiredService(); + + registry.IsEnabled(typeof(LifecycleHookFactory)).Should().BeTrue(); + registry.IsToggleable(typeof(LifecycleHookFactory)).Should().BeTrue(); + } + + [Test] + public void AddLifecycleHook_DirectHook_NonToggleable() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddTrax(trax => + trax.AddEffects(effects => effects.AddLifecycleHook(toggleable: false)) + ); + using var provider = services.BuildServiceProvider(); + + var registry = provider.GetRequiredService(); + + registry.IsToggleable(typeof(LifecycleHookFactory)).Should().BeFalse(); + } + + [Test] + public void AddLifecycleHook_DirectHook_FactoryCreatesCorrectHookType() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddTrax(trax => trax.AddEffects(effects => effects.AddLifecycleHook())); + using var provider = services.BuildServiceProvider(); + + var factory = provider + .GetServices() + .Single(f => f is LifecycleHookFactory); + var hook = factory.Create(); + + hook.Should().BeOfType(); + } + + [Test] + public void AddLifecycleHook_DirectHookWithDependency_InjectsDependency() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddSingleton(); + services.AddTrax(trax => + trax.AddEffects(effects => effects.AddLifecycleHook()) + ); + using var provider = services.BuildServiceProvider(); + + var factory = provider + .GetServices() + .Single(f => f is LifecycleHookFactory); + var hook = factory.Create() as HookWithDependency; + + hook.Should().NotBeNull(); + hook!.Dependency.Should().BeOfType(); + } + + [Test] + public void AddLifecycleHook_MultipleDirectHooks_AllRegistered() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddTrax(trax => + trax.AddEffects(effects => + effects.AddLifecycleHook().AddLifecycleHook() + ) + ); + using var provider = services.BuildServiceProvider(); + + var factories = provider.GetServices().ToList(); + + factories.Should().HaveCount(2); + factories.Should().Contain(f => f is LifecycleHookFactory); + factories.Should().Contain(f => f is LifecycleHookFactory); + } + + [Test] + public void AddLifecycleHook_MixedDirectAndFactory_AllRegistered() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddTrax(trax => + trax.AddEffects(effects => + effects.AddLifecycleHook().AddLifecycleHook() + ) + ); + using var provider = services.BuildServiceProvider(); + + var factories = provider.GetServices().ToList(); + + factories.Should().HaveCount(2); + factories.Should().Contain(f => f is LifecycleHookFactory); + factories.Should().Contain(f => f is StubHookFactory); + } + + [Test] + public void AddLifecycleHook_DirectHook_RunnerResolvable() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddTrax(trax => trax.AddEffects(effects => effects.AddLifecycleHook())); + using var provider = services.BuildServiceProvider(); + + var runner = provider.GetService(); + + runner.Should().NotBeNull(); + } + + #endregion + #region Test Stubs private class StubHook : ITrainLifecycleHook { } + private class AnotherStubHook : ITrainLifecycleHook { } + + private interface IDependency { } + + private class FakeDependency : IDependency { } + + private class HookWithDependency(AddLifecycleHookTests.IDependency dependency) + : ITrainLifecycleHook + { + public IDependency Dependency => dependency; + } + private class StubHookFactory : ITrainLifecycleHookFactory { public ITrainLifecycleHook Create() => new StubHook();