Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/Trax.Effect/Services/ServiceTrain/ServiceTrain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,23 +155,23 @@ public override async Task<TOut> Run(TIn input, CancellationToken cancellationTo
Metadata.AssertLoaded();
await EffectRunner.SaveChanges(CancellationToken);

await LifecycleHookRunner.OnStarted(Metadata, CancellationToken);

try
{
await OnStarted(Metadata, CancellationToken);
}
catch (Exception hookEx)
{
Logger?.LogError(
hookEx,
"Train-level OnStarted hook threw for train ({TrainName}).",
TrainName
);
}
await LifecycleHookRunner.OnStarted(Metadata, CancellationToken);

try
{
await OnStarted(Metadata, CancellationToken);
}
catch (Exception hookEx)
{
Logger?.LogError(
hookEx,
"Train-level OnStarted hook threw for train ({TrainName}).",
TrainName
);
}

try
{
Logger?.LogTrace("Running Train: ({TrainName})", TrainName);
Metadata.SetInputObject(input);
var result = await RunInternal(input);
Expand Down Expand Up @@ -332,7 +332,7 @@ public override async Task<TOut> Run(TIn input, CancellationToken cancellationTo
public virtual async Task<TOut> Run(TIn input, Metadata metadata)
{
await this.InitializeServiceTrain(metadata);
return await Run(input);
return await Run(input, CancellationToken);
}

/// <summary>
Expand All @@ -346,7 +346,7 @@ CancellationToken cancellationToken
{
CancellationToken = cancellationToken;
await this.InitializeServiceTrain(metadata);
return await Run(input);
return await Run(input, CancellationToken);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Trax.Core.Exceptions;
using Trax.Effect.Extensions;
using Trax.Effect.Models.Metadata;
using Trax.Effect.Models.Metadata.DTOs;
using Trax.Effect.Services.ServiceTrain;
using Trax.Effect.Tests.Data.InMemory.Integration.Fixtures;

Expand Down Expand Up @@ -243,6 +244,93 @@ public async Task Run_CancellationTokenPassedToHooks()
train.CompletedCancellationToken.Should().Be(cts.Token);
}

[Test]
public async Task Run_WithMetadataAndCancellationToken_TokenPreservedThroughHooks()
{
var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService<IRecordingTrain>();
using var cts = new CancellationTokenSource();
var metadata = Metadata.Create(
new CreateMetadata
{
Name = typeof(IRecordingTrain).FullName!,
ExternalId = Guid.NewGuid().ToString("N"),
Input = null,
}
);

await train.Run(Unit.Default, metadata, cts.Token);

train.StartedCancellationToken.Should().Be(cts.Token);
train.CompletedCancellationToken.Should().Be(cts.Token);
}

[Test]
public async Task Run_WithMetadataAndCancellationToken_TokenNotOverwrittenByDefault()
{
// Regression test: previously Run(input, metadata, ct) called Run(input)
// which overwrote the token with CancellationToken.None
var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService<IRecordingTrain>();
using var cts = new CancellationTokenSource();
var metadata = Metadata.Create(
new CreateMetadata
{
Name = typeof(IRecordingTrain).FullName!,
ExternalId = Guid.NewGuid().ToString("N"),
Input = null,
}
);

await train.Run(Unit.Default, metadata, cts.Token);

train
.StartedCancellationToken.Should()
.NotBe(
CancellationToken.None,
"the real token should survive through to hooks, not be overwritten with default"
);
}

[Test]
public async Task Run_WithMetadataOnly_DoesNotClobberToken()
{
// Run(input, metadata) should pass CancellationToken through without issue
var train = (RecordingTrain)Scope.ServiceProvider.GetRequiredService<IRecordingTrain>();
var metadata = Metadata.Create(
new CreateMetadata
{
Name = typeof(IRecordingTrain).FullName!,
ExternalId = Guid.NewGuid().ToString("N"),
Input = null,
}
);

await train.Run(Unit.Default, metadata);

train.StartedCalled.Should().BeTrue();
train.CompletedCalled.Should().BeTrue();
}

[Test]
public async Task Run_CancelledDuringExecution_WithMetadataOverload_CallsOnCancelled()
{
var train = (CancellingRecordingTrain)
Scope.ServiceProvider.GetRequiredService<ICancellingRecordingTrain>();
using var cts = new CancellationTokenSource();
var metadata = Metadata.Create(
new CreateMetadata
{
Name = typeof(ICancellingRecordingTrain).FullName!,
ExternalId = Guid.NewGuid().ToString("N"),
Input = null,
}
);

var act = async () => await train.Run(Unit.Default, metadata, cts.Token);
await act.Should().ThrowAsync<OperationCanceledException>();

train.CancelledCalled.Should().BeTrue();
}

#endregion

#region OutputSerialization
Expand Down
Loading