Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/Trax.Mediator/Configuration/MediatorConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,37 @@ public class MediatorConfiguration
/// Takes precedence over <see cref="Trax.Effect.Attributes.TraxConcurrencyLimitAttribute"/> values.
/// </summary>
internal Dictionary<string, int> ConcurrencyOverrides { get; } = new();

/// <summary>
/// When <c>true</c>, the host may start without registering an
/// <c>ITrainAuthorizationService</c> even though some trains carry
/// <c>[TraxAuthorize]</c>. Intended for scheduler-only or dashboard-only processes
/// that never accept API submissions. Opt in via
/// <c>TraxMediatorBuilder.AllowMissingAuthorizationService()</c>.
/// </summary>
public bool AllowMissingAuthorizationService { get; internal set; }

/// <summary>
/// Maximum UTF-8 byte length for caller-supplied train input JSON in
/// <c>ITrainExecutionService.RunAsync</c> / <c>QueueAsync</c>. Defaults to
/// 256 KiB. Override via <c>TraxMediatorBuilder.WithMaxInputJsonBytes(int)</c>.
/// </summary>
/// <remarks>
/// The cap is enforced post-authorization but pre-deserialization so that
/// attacker-controlled JSON cannot exhaust memory or trigger deserializer
/// gadget chains before any fail-closed check has run. Queued work entries
/// created by <c>QueueAsync</c> are re-serialized from the parsed CLR object
/// and are therefore governed by the same cap indirectly.
/// </remarks>
public int MaxInputJsonBytes { get; internal set; } = 262_144;

/// <summary>
/// Maximum concurrent RUN executions per authenticated principal. When the
/// limit is reached, additional requests from the same principal queue on a
/// per-principal semaphore until an in-flight request completes. Defaults to
/// <c>null</c> (no per-principal cap — global and per-train limits still apply).
/// Requires <c>IHttpContextAccessor</c> to be registered; without it the cap
/// has no effect.
/// </summary>
public int? PerPrincipalMaxConcurrentRun { get; internal set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ internal MediatorConfiguration Build()
TrainLifetime = _lifetime,
Assemblies = [.. _assemblies],
GlobalMaxConcurrentRun = _globalMaxConcurrentRun,
AllowMissingAuthorizationService = _allowMissingAuthorizationService,
MaxInputJsonBytes = _maxInputJsonBytes,
PerPrincipalMaxConcurrentRun = _perPrincipalMaxConcurrentRun,
};

foreach (var (trainName, limit) in _concurrencyOverrides)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,66 @@ public TraxMediatorBuilder TrainLifetime(ServiceLifetime lifetime)
_lifetime = lifetime;
return this;
}

/// <summary>
/// Opts the host out of the startup check that fails when trains carry
/// <c>[TraxAuthorize]</c> but no <c>ITrainAuthorizationService</c> is registered.
/// Intended for processes that never serve API submissions (e.g. a standalone
/// scheduler, a dashboard-only host). Do NOT use from an API host.
/// </summary>
/// <remarks>
/// Calling this flips the fail-closed default in
/// <see cref="Services.TrainExecution.TrainExecutionService"/> to a silent no-op
/// when the authorization service is missing. Misapplication produces a process
/// that silently runs authorized trains without any authorization check.
/// </remarks>
public TraxMediatorBuilder AllowMissingAuthorizationService()
{
_allowMissingAuthorizationService = true;
return this;
}

/// <summary>
/// Sets the maximum UTF-8 byte length for caller-supplied input JSON in
/// <c>ITrainExecutionService.RunAsync</c> / <c>QueueAsync</c>. Default is
/// 256 KiB (262144 bytes). Inputs that exceed the cap are rejected with
/// <see cref="Exceptions.TrainInputValidationException"/> before any
/// deserialization runs.
/// </summary>
/// <param name="bytes">Maximum accepted size in bytes. Must be positive.</param>
public TraxMediatorBuilder WithMaxInputJsonBytes(int bytes)
{
if (bytes <= 0)
throw new ArgumentOutOfRangeException(
nameof(bytes),
bytes,
"MaxInputJsonBytes must be positive."
);
_maxInputJsonBytes = bytes;
return this;
}

/// <summary>
/// Caps the number of concurrent RUN executions for a single authenticated
/// principal (keyed by the <c>trax:principal-id</c> claim). Prevents a single
/// authenticated caller from saturating the global or per-train concurrency
/// budget via request fan-out. Default is <c>null</c> (no cap).
/// </summary>
/// <remarks>
/// Requires <c>IHttpContextAccessor</c> in DI (registered automatically by
/// <c>AddTraxApi</c>). Calls made without an HttpContext (scheduler, remote
/// worker, trusted scope) are not subject to the cap — those paths are
/// already gated by the global and per-train limits.
/// </remarks>
public TraxMediatorBuilder PerPrincipalMaxConcurrentRun(int limit)
{
if (limit <= 0)
throw new ArgumentOutOfRangeException(
nameof(limit),
limit,
"PerPrincipalMaxConcurrentRun must be positive."
);
_perPrincipalMaxConcurrentRun = limit;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ public partial class TraxMediatorBuilder
private readonly List<Assembly> _assemblies = [];
private int? _globalMaxConcurrentRun;
private readonly Dictionary<string, int> _concurrencyOverrides = new();
private bool _allowMissingAuthorizationService;
private int _maxInputJsonBytes = 262_144;
private int? _perPrincipalMaxConcurrentRun;

internal TraxMediatorBuilder(TraxBuilderWithEffects parent)
{
Expand Down
32 changes: 32 additions & 0 deletions src/Trax.Mediator/Exceptions/AmbiguousTrainNameException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
namespace Trax.Mediator.Exceptions;

/// <summary>
/// Thrown when a caller references a train by a friendly (short) name that matches
/// more than one registered train. Disambiguate by passing the interface
/// <see cref="Type.FullName"/> instead.
/// </summary>
/// <remarks>
/// Registered trains that share an unqualified class name across different namespaces
/// are allowed. The lookup only fails when the disambiguating name (<see cref="Type.FullName"/>)
/// is not provided. This exception is thrown as a fail-closed guard rather than picking
/// an arbitrary candidate.
/// </remarks>
public class AmbiguousTrainNameException : InvalidOperationException
{
public string RequestedName { get; }
public IReadOnlyList<string> CandidateFullNames { get; }

public AmbiguousTrainNameException(
string requestedName,
IReadOnlyList<string> candidateFullNames
)
: base(
$"Train name '{requestedName}' is ambiguous. Matching registrations: "
+ string.Join(", ", candidateFullNames)
+ ". Pass the interface FullName to disambiguate."
)
{
RequestedName = requestedName;
CandidateFullNames = candidateFullNames;
}
}
26 changes: 26 additions & 0 deletions src/Trax.Mediator/Exceptions/TrainInputValidationException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
namespace Trax.Mediator.Exceptions;

/// <summary>
/// Thrown by <see cref="Services.TrainExecution.TrainExecutionService"/> when the
/// caller-supplied input JSON fails a pre-deserialization check (for example, it
/// exceeds the configured maximum size).
/// </summary>
/// <remarks>
/// The public <see cref="Exception.Message"/> is intentionally generic. Detailed
/// context (the offending size, the cap, etc.) is available on properties for
/// server-side logging.
/// </remarks>
public class TrainInputValidationException : InvalidOperationException
{
public string TrainName { get; }
public int ObservedBytes { get; }
public int MaxBytes { get; }

public TrainInputValidationException(string trainName, int observedBytes, int maxBytes)
: base("The train input failed validation.")
{
TrainName = trainName;
ObservedBytes = observedBytes;
MaxBytes = maxBytes;
}
}
22 changes: 22 additions & 0 deletions src/Trax.Mediator/Exceptions/TrainNotFoundException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
namespace Trax.Mediator.Exceptions;

/// <summary>
/// Thrown by <see cref="Services.TrainExecution.TrainExecutionService"/> when a caller
/// references a train name that does not match any registered train.
/// </summary>
/// <remarks>
/// The public <see cref="Exception.Message"/> is intentionally generic and never
/// contains the requested train name. Enumerating registered trains via error
/// messages would let unauthenticated callers probe the API surface. Diagnostic
/// detail lives in <see cref="RequestedName"/> for server-side logging only.
/// </remarks>
public class TrainNotFoundException : InvalidOperationException
{
public string RequestedName { get; }

public TrainNotFoundException(string requestedName)
: base("The requested train was not found.")
{
RequestedName = requestedName;
}
}
10 changes: 10 additions & 0 deletions src/Trax.Mediator/Extensions/ServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@
using System.Linq;
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Trax.Core.Exceptions;
using Trax.Effect.Configuration.TraxBuilder;
using Trax.Effect.Extensions;
using Trax.Effect.Services.ServiceTrain;
using Trax.Mediator.Configuration;
using Trax.Mediator.Services.ConcurrencyLimiter;
using Trax.Mediator.Services.Principal;
using Trax.Mediator.Services.RunExecutor;
using Trax.Mediator.Services.TrainAuthorization;
using Trax.Mediator.Services.TrainBus;
using Trax.Mediator.Services.TrainDiscovery;
using Trax.Mediator.Services.TrainExecution;
using Trax.Mediator.Services.TrainRegistry;
using Trax.Mediator.Services.TrustedExecution;

namespace Trax.Mediator.Extensions;

Expand Down Expand Up @@ -177,6 +181,12 @@ params Assembly[] assemblies
.AddSingleton<ITrainRegistry>(trainRegistry)
.AddSingleton<ITrainDiscoveryService, TrainDiscoveryService>()
.AddSingleton<IConcurrencyLimiter, ConcurrencyLimiter>()
.AddSingleton<ITrustedExecutionScope, TrustedExecutionScope>()
// Default null-returning principal provider. Hosts with an HTTP
// pipeline replace this via AddTraxApi with an HttpContext-backed
// implementation so per-principal concurrency caps activate.
.AddSingleton<ICurrentPrincipalProvider, NullPrincipalProvider>()
.AddHostedService<AuthorizationRegistrationValidator>()
.AddScoped<ITrainBus, TrainBus>()
.AddScoped<IRunExecutor, LocalRunExecutor>()
.AddScoped<ITrainExecutionService, TrainExecutionService>()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
using System.Collections.Concurrent;
using Trax.Mediator.Configuration;
using Trax.Mediator.Services.Principal;
using Trax.Mediator.Services.TrainDiscovery;

namespace Trax.Mediator.Services.ConcurrencyLimiter;

/// <summary>
/// Singleton service that manages per-train and global concurrency limits for RUN executions.
/// Uses <see cref="SemaphoreSlim"/> instances keyed by train interface FullName.
/// Singleton service that manages per-train, per-principal, and global concurrency
/// limits for RUN executions. Uses <see cref="SemaphoreSlim"/> instances keyed by
/// train interface FullName and (when applicable) principal id.
/// </summary>
public class ConcurrencyLimiter : IConcurrencyLimiter
{
private readonly MediatorConfiguration _configuration;
private readonly ITrainDiscoveryService _discoveryService;
private readonly ICurrentPrincipalProvider _principalProvider;
private readonly ConcurrentDictionary<string, Lazy<SemaphoreSlim?>> _perTrainSemaphores = new();
private readonly ConcurrentDictionary<string, SemaphoreSlim> _perPrincipalSemaphores = new();
private readonly SemaphoreSlim? _globalSemaphore;

public ConcurrencyLimiter(
MediatorConfiguration configuration,
ITrainDiscoveryService discoveryService
ITrainDiscoveryService discoveryService,
ICurrentPrincipalProvider principalProvider
)
{
_configuration = configuration;
_discoveryService = discoveryService;
_principalProvider = principalProvider;
_globalSemaphore = configuration.GlobalMaxConcurrentRun is { } globalLimit
? new SemaphoreSlim(globalLimit, globalLimit)
: null;
Expand All @@ -30,24 +36,37 @@ ITrainDiscoveryService discoveryService
public async Task<IDisposable> AcquireAsync(string trainFullName, CancellationToken ct)
{
var perTrainSemaphore = GetOrCreatePerTrainSemaphore(trainFullName);
var perPrincipalSemaphore = GetOrCreatePerPrincipalSemaphore();

// Acquire per-train first, then global — deterministic order prevents deadlocks
// Acquire in a deterministic order (per-train → per-principal → global)
// to prevent cross-lock deadlocks. Release in reverse.
if (perTrainSemaphore is not null)
await perTrainSemaphore.WaitAsync(ct);

try
{
if (perPrincipalSemaphore is not null)
await perPrincipalSemaphore.WaitAsync(ct);
}
catch
{
perTrainSemaphore?.Release();
throw;
}

try
{
if (_globalSemaphore is not null)
await _globalSemaphore.WaitAsync(ct);
}
catch
{
// If global acquire fails (cancellation), release the per-train permit we already hold
perPrincipalSemaphore?.Release();
perTrainSemaphore?.Release();
throw;
}

return new ConcurrencyPermit(perTrainSemaphore, _globalSemaphore);
return new ConcurrencyPermit(perTrainSemaphore, perPrincipalSemaphore, _globalSemaphore);
}

private SemaphoreSlim? GetOrCreatePerTrainSemaphore(string trainFullName)
Expand All @@ -64,6 +83,18 @@ public async Task<IDisposable> AcquireAsync(string trainFullName, CancellationTo
.Value;
}

private SemaphoreSlim? GetOrCreatePerPrincipalSemaphore()
{
if (_configuration.PerPrincipalMaxConcurrentRun is not { } limit)
return null;

var principalId = _principalProvider.GetCurrentPrincipalId();
if (string.IsNullOrEmpty(principalId))
return null;

return _perPrincipalSemaphores.GetOrAdd(principalId, _ => new SemaphoreSlim(limit, limit));
}

private int? ResolveLimit(string trainFullName)
{
// Priority 1: Builder override
Expand All @@ -78,8 +109,11 @@ public async Task<IDisposable> AcquireAsync(string trainFullName, CancellationTo
return registration?.MaxConcurrentRun;
}

private sealed class ConcurrencyPermit(SemaphoreSlim? perTrain, SemaphoreSlim? global)
: IDisposable
private sealed class ConcurrencyPermit(
SemaphoreSlim? perTrain,
SemaphoreSlim? perPrincipal,
SemaphoreSlim? global
) : IDisposable
{
private int _disposed;

Expand All @@ -90,6 +124,7 @@ public void Dispose()

// Release in reverse order of acquisition
global?.Release();
perPrincipal?.Release();
perTrain?.Release();
}
}
Expand Down
24 changes: 24 additions & 0 deletions src/Trax.Mediator/Services/Principal/ICurrentPrincipalProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
namespace Trax.Mediator.Services.Principal;

/// <summary>
/// Abstracts "who is the current authenticated caller" for mediator-level
/// services that need it (e.g. per-principal concurrency limits). The mediator
/// itself does not depend on ASP.NET Core; hosts that run inside an HTTP pipeline
/// register an implementation that reads the principal id from an HttpContext.
/// </summary>
/// <remarks>
/// Implementations should be cheap and thread-safe; a call-site may invoke
/// <see cref="GetCurrentPrincipalId"/> on the hot path. Return <c>null</c> when
/// no authenticated principal is available (scheduler, remote worker, anonymous
/// route).
/// </remarks>
public interface ICurrentPrincipalProvider
{
/// <summary>
/// The stable identifier of the current authenticated principal, or
/// <c>null</c> when no principal is present. Mediator-level services use
/// this as a bucketing key; the exact semantics (claim name, shape) are
/// the host's responsibility.
/// </summary>
string? GetCurrentPrincipalId();
}
12 changes: 12 additions & 0 deletions src/Trax.Mediator/Services/Principal/NullPrincipalProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace Trax.Mediator.Services.Principal;

/// <summary>
/// Default <see cref="ICurrentPrincipalProvider"/> that always returns <c>null</c>.
/// Used in hosts that do not run inside an HTTP request pipeline (scheduler,
/// remote worker). Replaced by <c>AddTraxApi</c> with an HttpContext-backed
/// implementation.
/// </summary>
internal sealed class NullPrincipalProvider : ICurrentPrincipalProvider
{
public string? GetCurrentPrincipalId() => null;
}
Loading
Loading