diff --git a/src/Trax.Scheduler.Lambda/Configuration/LambdaRetryOptions.cs b/src/Trax.Scheduler.Lambda/Configuration/LambdaRetryOptions.cs
new file mode 100644
index 0000000..7f37ffa
--- /dev/null
+++ b/src/Trax.Scheduler.Lambda/Configuration/LambdaRetryOptions.cs
@@ -0,0 +1,29 @@
+namespace Trax.Scheduler.Lambda.Configuration;
+
+/// <summary>
+/// Retry options for transient AWS Lambda invocation failures (throttling, service errors).
+/// </summary>
+/// <remarks>
+/// Applied to <see cref="LambdaWorkerOptions"/> and <see cref="LambdaRunOptions"/>.
+/// Retries on AWS status codes 429 (Throttling), 502 (Bad Gateway), 503 (Service Unavailable),
+/// and 504 (Gateway Timeout) with exponential backoff and jitter.
+/// Set <see cref="MaxRetries"/> to 0 to disable retries.
+/// </remarks>
+public class LambdaRetryOptions
+{
+ /// <summary>
+ /// Maximum number of retry attempts before giving up.
+ /// </summary>
+ public int MaxRetries { get; set; } = 5;
+
+ /// <summary>
+ /// Base delay between retries. Actual delay is BaseDelay * 2^attempt with jitter,
+ /// capped at <see cref="MaxDelay"/>.
+ /// </summary>
+ public TimeSpan BaseDelay { get; set; } = TimeSpan.FromSeconds(1);
+
+ /// <summary>
+ /// Maximum delay between retries, preventing unbounded exponential growth.
+ /// </summary>
+ public TimeSpan MaxDelay { get; set; } = TimeSpan.FromSeconds(30);
+}
diff --git a/src/Trax.Scheduler.Lambda/Configuration/LambdaRunOptions.cs b/src/Trax.Scheduler.Lambda/Configuration/LambdaRunOptions.cs
index a96a352..5087f3b 100644
--- a/src/Trax.Scheduler.Lambda/Configuration/LambdaRunOptions.cs
+++ b/src/Trax.Scheduler.Lambda/Configuration/LambdaRunOptions.cs
@@ -25,4 +25,9 @@ public class LambdaRunOptions
/// Use this to set a custom region, endpoint override (e.g., LocalStack), or service URL.
///
public Action? ConfigureLambdaClient { get; set; }
+
+ /// <summary>
+ /// Retry options for transient Lambda invocation failures (throttling, service errors).
+ /// </summary>
+ public LambdaRetryOptions Retry { get; set; } = new();
}
diff --git a/src/Trax.Scheduler.Lambda/Configuration/LambdaWorkerOptions.cs b/src/Trax.Scheduler.Lambda/Configuration/LambdaWorkerOptions.cs
index 1dc9d05..aec9cc6 100644
--- a/src/Trax.Scheduler.Lambda/Configuration/LambdaWorkerOptions.cs
+++ b/src/Trax.Scheduler.Lambda/Configuration/LambdaWorkerOptions.cs
@@ -27,4 +27,9 @@ public class LambdaWorkerOptions
/// Use this to set a custom region, endpoint override (e.g., LocalStack), or service URL.
///
public Action? ConfigureLambdaClient { get; set; }
+
+ /// <summary>
+ /// Retry options for transient Lambda invocation failures (throttling, service errors).
+ /// </summary>
+ public LambdaRetryOptions Retry { get; set; } = new();
}
diff --git a/src/Trax.Scheduler.Lambda/Services/LambdaJobSubmitter.cs b/src/Trax.Scheduler.Lambda/Services/LambdaJobSubmitter.cs
index 74ff2da..d283f91 100644
--- a/src/Trax.Scheduler.Lambda/Services/LambdaJobSubmitter.cs
+++ b/src/Trax.Scheduler.Lambda/Services/LambdaJobSubmitter.cs
@@ -81,7 +81,13 @@ private async Task InvokeAsync(RemoteJobRequest request, CancellationToken cance
request.MetadataId
);
- var response = await lambdaClient.InvokeAsync(invokeRequest, cancellationToken);
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ lambdaClient,
+ invokeRequest,
+ options.Retry,
+ logger,
+ cancellationToken
+ );
if (!string.IsNullOrEmpty(response.FunctionError))
{
diff --git a/src/Trax.Scheduler.Lambda/Services/LambdaRetryHelper.cs b/src/Trax.Scheduler.Lambda/Services/LambdaRetryHelper.cs
new file mode 100644
index 0000000..5d1a6b0
--- /dev/null
+++ b/src/Trax.Scheduler.Lambda/Services/LambdaRetryHelper.cs
@@ -0,0 +1,113 @@
+using System.Net;
+using Amazon.Lambda;
+using Amazon.Lambda.Model;
+using Amazon.Runtime;
+using Microsoft.Extensions.Logging;
+using Trax.Scheduler.Lambda.Configuration;
+
+namespace Trax.Scheduler.Lambda.Services;
+
+/// <summary>
+/// Retries AWS Lambda invocations on transient failures (throttling, 502, 503, 504) with exponential backoff and jitter.
+/// </summary>
+internal static class LambdaRetryHelper
+{
+    private static readonly HashSet<HttpStatusCode> TransientStatusCodes =
+    [
+        HttpStatusCode.TooManyRequests,
+        HttpStatusCode.BadGateway,
+        HttpStatusCode.ServiceUnavailable,
+        HttpStatusCode.GatewayTimeout,
+    ];
+
+    /// <summary>
+    /// Invokes a Lambda function, retrying transient AWS failures with exponential backoff and jitter.
+    /// </summary>
+    /// <param name="client">The Lambda client to use.</param>
+    /// <param name="request">The invocation request.</param>
+    /// <param name="options">Retry configuration.</param>
+    /// <param name="logger">Logger for retry diagnostics; may be <c>null</c>.</param>
+    /// <param name="ct">Cancellation token, passed to the invocation and to the backoff delay.</param>
+    /// <returns>The response of the first attempt that does not throw.</returns>
+    /// <remarks>
+    /// At most <c>MaxRetries + 1</c> invocations are made. Non-transient exceptions, and the
+    /// transient exception raised by the final attempt, propagate unchanged with their
+    /// original stack trace.
+    /// </remarks>
+    internal static async Task<InvokeResponse> InvokeWithRetryAsync(
+        IAmazonLambda client,
+        InvokeRequest request,
+        LambdaRetryOptions options,
+        ILogger? logger,
+        CancellationToken ct
+    )
+    {
+        // A negative MaxRetries is treated as "no retries".
+        var maxRetries = Math.Max(0, options.MaxRetries);
+
+        for (var attempt = 0; ; attempt++)
+        {
+            try
+            {
+                return await client.InvokeAsync(request, ct);
+            }
+            // The filter declines the final attempt, so that exception escapes this method
+            // untouched instead of being captured and re-thrown later, which would reset
+            // its stack trace.
+            catch (Exception ex) when (attempt < maxRetries && IsTransient(ex))
+            {
+                var delay = ComputeDelay(attempt, options);
+
+                logger?.LogWarning(
+                    "Lambda invocation failed with transient error ({ErrorType}), retrying in {DelayMs}ms (attempt {Attempt}/{MaxRetries})",
+                    ex.GetType().Name,
+                    delay.TotalMilliseconds,
+                    attempt + 1,
+                    maxRetries
+                );
+
+                await Task.Delay(delay, ct);
+            }
+        }
+    }
+
+    /// <summary>
+    /// Determines whether an exception represents a transient AWS failure that should be retried:
+    /// an <see cref="AmazonServiceException"/> with status 429, 502, 503 or 504, or any
+    /// <see cref="HttpRequestException"/>.
+    /// </summary>
+    internal static bool IsTransient(Exception ex) =>
+        ex switch
+        {
+            AmazonServiceException serviceException =>
+                TransientStatusCodes.Contains(serviceException.StatusCode),
+            HttpRequestException => true,
+            _ => false,
+        };
+
+    /// <summary>
+    /// Computes the delay for a given retry attempt using exponential backoff with jitter,
+    /// clamped to the range [zero, <c>MaxDelay</c>].
+    /// </summary>
+    internal static TimeSpan ComputeDelay(int attempt, LambdaRetryOptions options)
+    {
+        // Exponential backoff: baseDelay * 2^attempt
+        var exponentialMs = options.BaseDelay.TotalMilliseconds * Math.Pow(2, attempt);
+
+        // Add jitter: +/-25%
+        var jitterFactor = 0.75 + Random.Shared.NextDouble() * 0.5;
+
+        return Clamp(exponentialMs * jitterFactor, options);
+    }
+
+    // Clamps in the double domain: TimeSpan.FromMilliseconds throws OverflowException for
+    // values beyond TimeSpan.MaxValue, which 2^attempt reaches for large attempt numbers.
+    private static TimeSpan Clamp(double delayMs, LambdaRetryOptions options)
+    {
+        if (double.IsNaN(delayMs) || delayMs <= 0)
+            return TimeSpan.Zero;
+
+        var maxMs = Math.Max(0, options.MaxDelay.TotalMilliseconds);
+        return TimeSpan.FromMilliseconds(Math.Min(delayMs, maxMs));
+    }
+}
diff --git a/src/Trax.Scheduler.Lambda/Services/LambdaRunExecutor.cs b/src/Trax.Scheduler.Lambda/Services/LambdaRunExecutor.cs
index 6bd2be4..7710948 100644
--- a/src/Trax.Scheduler.Lambda/Services/LambdaRunExecutor.cs
+++ b/src/Trax.Scheduler.Lambda/Services/LambdaRunExecutor.cs
@@ -63,7 +63,13 @@ public async Task ExecuteAsync(
trainName
);
- var invokeResponse = await lambdaClient.InvokeAsync(invokeRequest, ct);
+ var invokeResponse = await LambdaRetryHelper.InvokeWithRetryAsync(
+ lambdaClient,
+ invokeRequest,
+ options.Retry,
+ logger,
+ ct
+ );
if (!string.IsNullOrEmpty(invokeResponse.FunctionError))
{
diff --git a/src/Trax.Scheduler.Lambda/Trax.Scheduler.Lambda.csproj b/src/Trax.Scheduler.Lambda/Trax.Scheduler.Lambda.csproj
index 3907645..5f53362 100644
--- a/src/Trax.Scheduler.Lambda/Trax.Scheduler.Lambda.csproj
+++ b/src/Trax.Scheduler.Lambda/Trax.Scheduler.Lambda.csproj
@@ -9,6 +9,11 @@
Trax.Scheduler.Lambda
+
+
+
+
+
diff --git a/tests/Trax.Scheduler.Tests.Integration/UnitTests/LambdaJobSubmitterTests.cs b/tests/Trax.Scheduler.Tests.Integration/UnitTests/LambdaJobSubmitterTests.cs
index cf56f89..4841375 100644
--- a/tests/Trax.Scheduler.Tests.Integration/UnitTests/LambdaJobSubmitterTests.cs
+++ b/tests/Trax.Scheduler.Tests.Integration/UnitTests/LambdaJobSubmitterTests.cs
@@ -1,6 +1,8 @@
+using System.Net;
using System.Text.Json;
using Amazon.Lambda;
using Amazon.Lambda.Model;
+using Amazon.Runtime;
using FluentAssertions;
using Trax.Core.Exceptions;
using Trax.Effect.Utils;
@@ -282,6 +284,43 @@ public async Task EnqueueAsync_MultipleCalls_EachProducesUniqueJobId()
#endregion
+ #region Retry Tests
+
+ [Test]
+ public async Task EnqueueAsync_ThrottleThenSuccess_RetriesTransparentlyAndReturnsJobId()
+ {
+ // Arrange: the mock throws a single 429 before it starts returning normally
+ var options = new LambdaWorkerOptions
+ {
+ FunctionName = "my-lambda-function",
+ Retry = new LambdaRetryOptions
+ {
+ MaxRetries = 3,
+ BaseDelay = TimeSpan.FromMilliseconds(1), // tiny base delay so the backoff wait is negligible
+ },
+ };
+ var client = new MockLambdaClient();
+ client.ExceptionsBeforeSuccess.Enqueue(
+ new AmazonServiceException("Throttled") { StatusCode = HttpStatusCode.TooManyRequests }
+ );
+ var logger = Microsoft
+ .Extensions
+ .Logging
+ .Abstractions
+ .NullLogger
+ .Instance;
+ var submitter = new LambdaJobSubmitter(client, options, logger);
+
+ // Act: the first invocation is throttled, the retry goes through
+ var jobId = await submitter.EnqueueAsync(42);
+
+ // Assert
+ jobId.Should().StartWith("lambda-");
+ client.AllRequests.Should().HaveCount(2); // the mock records every attempt: 1 throttled + 1 success
+ }
+
+ #endregion
+
#region MockLambdaClient
internal class MockLambdaClient : IAmazonLambda
@@ -291,6 +330,12 @@ internal class MockLambdaClient : IAmazonLambda
public bool ThrowOnInvoke { get; set; }
public string? FunctionError { get; set; }
+ ///
+ /// Optional queue of exceptions to throw before returning a successful response.
+ /// Each invocation pops the next exception; once empty, returns normally.
+ ///
+ public Queue ExceptionsBeforeSuccess { get; } = new();
+
public Amazon.Runtime.IClientConfig Config => throw new NotImplementedException();
public ILambdaPaginatorFactory Paginators => throw new NotImplementedException();
@@ -302,11 +347,15 @@ public Task InvokeAsync(
{
cancellationToken.ThrowIfCancellationRequested();
+ AllRequests.Add(request);
+
+ if (ExceptionsBeforeSuccess.Count > 0)
+ throw ExceptionsBeforeSuccess.Dequeue();
+
if (ThrowOnInvoke)
throw new AmazonLambdaException("Mock Lambda error");
LastRequest = request;
- AllRequests.Add(request);
return Task.FromResult(
new InvokeResponse { StatusCode = 202, FunctionError = FunctionError }
diff --git a/tests/Trax.Scheduler.Tests/UnitTests/LambdaRetryHelperTests.cs b/tests/Trax.Scheduler.Tests/UnitTests/LambdaRetryHelperTests.cs
new file mode 100644
index 0000000..ec240d6
--- /dev/null
+++ b/tests/Trax.Scheduler.Tests/UnitTests/LambdaRetryHelperTests.cs
@@ -0,0 +1,951 @@
+using System.Net;
+using Amazon.Lambda;
+using Amazon.Lambda.Model;
+using Amazon.Runtime;
+using FluentAssertions;
+using Microsoft.Extensions.Logging.Abstractions;
+using Trax.Scheduler.Lambda.Configuration;
+using Trax.Scheduler.Lambda.Services;
+
+namespace Trax.Scheduler.Tests.UnitTests;
+
+[TestFixture]
+public class LambdaRetryHelperTests
+{
+ private static LambdaRetryOptions FastOptions(int maxRetries = 3) =>
+ new()
+ {
+ MaxRetries = maxRetries,
+ BaseDelay = TimeSpan.FromMilliseconds(1),
+ MaxDelay = TimeSpan.FromSeconds(30),
+ };
+
+ private static InvokeRequest DefaultRequest =>
+ new() { FunctionName = "test-fn", Payload = "{}" };
+
+ #region Retry on Transient Exceptions
+
+ [Test]
+ public async Task InvokeWithRetryAsync_429ThenSuccess_RetriesAndReturnsResponse()
+ {
+ var client = new SequentialMockLambdaClient([
+ Throw(HttpStatusCode.TooManyRequests),
+ Succeed(),
+ ]);
+
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ response.StatusCode.Should().Be(200);
+ client.InvokeCount.Should().Be(2);
+ }
+
+ [Test]
+ public async Task InvokeWithRetryAsync_502ThenSuccess_RetriesAndReturnsResponse()
+ {
+ var client = new SequentialMockLambdaClient([Throw(HttpStatusCode.BadGateway), Succeed()]);
+
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ response.StatusCode.Should().Be(200);
+ client.InvokeCount.Should().Be(2);
+ }
+
+ [Test]
+ public async Task InvokeWithRetryAsync_503ThenSuccess_RetriesAndReturnsResponse()
+ {
+ var client = new SequentialMockLambdaClient([
+ Throw(HttpStatusCode.ServiceUnavailable),
+ Succeed(),
+ ]);
+
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ response.StatusCode.Should().Be(200);
+ client.InvokeCount.Should().Be(2);
+ }
+
+ [Test]
+ public async Task InvokeWithRetryAsync_504ThenSuccess_RetriesAndReturnsResponse()
+ {
+ var client = new SequentialMockLambdaClient([
+ Throw(HttpStatusCode.GatewayTimeout),
+ Succeed(),
+ ]);
+
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ response.StatusCode.Should().Be(200);
+ client.InvokeCount.Should().Be(2);
+ }
+
+ [Test]
+ public async Task InvokeWithRetryAsync_HttpRequestExceptionThenSuccess_Retries()
+ {
+ var client = new SequentialMockLambdaClient([ThrowHttp("connection reset"), Succeed()]);
+
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ response.StatusCode.Should().Be(200);
+ client.InvokeCount.Should().Be(2);
+ }
+
+ [Test]
+ public async Task InvokeWithRetryAsync_Multiple429ThenSuccess_RetriesMultipleTimes()
+ {
+ var client = new SequentialMockLambdaClient([
+ Throw(HttpStatusCode.TooManyRequests),
+ Throw(HttpStatusCode.TooManyRequests),
+ Throw(HttpStatusCode.TooManyRequests),
+ Succeed(),
+ ]);
+
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(5),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ response.StatusCode.Should().Be(200);
+ client.InvokeCount.Should().Be(4);
+ }
+
+ #endregion
+
+ #region Exhaust Retries
+
+ [Test]
+ public async Task InvokeWithRetryAsync_429ExceedsMaxRetries_ThrowsLastException()
+ {
+ var client = new SequentialMockLambdaClient([
+ Throw(HttpStatusCode.TooManyRequests),
+ Throw(HttpStatusCode.TooManyRequests),
+ Throw(HttpStatusCode.TooManyRequests),
+ Throw(HttpStatusCode.TooManyRequests),
+ ]);
+
+ var act = async () =>
+ await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(3),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ await act.Should().ThrowAsync();
+ client.InvokeCount.Should().Be(4); // 1 initial + 3 retries
+ }
+
+ #endregion
+
+ #region Non-Transient Exceptions
+
+ [Test]
+ public async Task InvokeWithRetryAsync_ResourceNotFound_DoesNotRetry()
+ {
+ var client = new SequentialMockLambdaClient([
+ ThrowCustom(new ResourceNotFoundException("not found")),
+ ]);
+
+ var act = async () =>
+ await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ await act.Should().ThrowAsync();
+ client.InvokeCount.Should().Be(1);
+ }
+
+ [Test]
+ public async Task InvokeWithRetryAsync_InvalidParameterValue_DoesNotRetry()
+ {
+ var client = new SequentialMockLambdaClient([
+ ThrowCustom(new InvalidParameterValueException("bad param")),
+ ]);
+
+ var act = async () =>
+ await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ await act.Should().ThrowAsync();
+ client.InvokeCount.Should().Be(1);
+ }
+
+ [Test]
+ public async Task InvokeWithRetryAsync_500InternalServerError_DoesNotRetry()
+ {
+ var client = new SequentialMockLambdaClient([Throw(HttpStatusCode.InternalServerError)]);
+
+ var act = async () =>
+ await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ await act.Should().ThrowAsync();
+ client.InvokeCount.Should().Be(1);
+ }
+
+ #endregion
+
+ #region MaxRetries = 0 (Disabled)
+
+ [Test]
+ public async Task InvokeWithRetryAsync_MaxRetriesZero_DoesNotRetry()
+ {
+ var client = new SequentialMockLambdaClient([
+ Throw(HttpStatusCode.TooManyRequests),
+ Succeed(),
+ ]);
+
+ var act = async () =>
+ await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(0),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ await act.Should().ThrowAsync();
+ client.InvokeCount.Should().Be(1);
+ }
+
+ #endregion
+
+ #region Success Without Retry
+
+ [Test]
+ public async Task InvokeWithRetryAsync_ImmediateSuccess_DoesNotRetry()
+ {
+ var client = new SequentialMockLambdaClient([Succeed()]);
+
+ var response = await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ FastOptions(),
+ NullLogger.Instance,
+ CancellationToken.None
+ );
+
+ response.StatusCode.Should().Be(200);
+ client.InvokeCount.Should().Be(1);
+ }
+
+ #endregion
+
+ #region ComputeDelay
+
+ [Test]
+ public void ComputeDelay_Attempt0_ReturnsAroundBaseDelay()
+ {
+ var options = new LambdaRetryOptions
+ {
+ BaseDelay = TimeSpan.FromSeconds(1),
+ MaxDelay = TimeSpan.FromSeconds(30),
+ };
+
+ var delay = LambdaRetryHelper.ComputeDelay(0, options);
+
+ // Base delay * 2^0 = 1s, with +/-25% jitter -> 0.75s to 1.25s
+ delay.TotalMilliseconds.Should().BeInRange(750, 1250);
+ }
+
+ [Test]
+ public void ComputeDelay_Attempt3_ReturnsExponentiallyHigher()
+ {
+ var options = new LambdaRetryOptions
+ {
+ BaseDelay = TimeSpan.FromSeconds(1),
+ MaxDelay = TimeSpan.FromSeconds(30),
+ };
+
+ var delay = LambdaRetryHelper.ComputeDelay(3, options);
+
+ // Base delay * 2^3 = 8s, with +/-25% jitter -> 6s to 10s
+ delay.TotalMilliseconds.Should().BeInRange(6000, 10000);
+ }
+
+ [Test]
+ public void ComputeDelay_LargeAttempt_CapsAtMaxDelay()
+ {
+ var options = new LambdaRetryOptions
+ {
+ BaseDelay = TimeSpan.FromSeconds(1),
+ MaxDelay = TimeSpan.FromSeconds(30),
+ };
+
+ var delay = LambdaRetryHelper.ComputeDelay(10, options);
+
+ delay.Should().BeLessThanOrEqualTo(TimeSpan.FromSeconds(30));
+ }
+
+ [Test]
+ public void ComputeDelay_JitterApplied_ProducesDifferentValues()
+ {
+ var options = new LambdaRetryOptions
+ {
+ BaseDelay = TimeSpan.FromSeconds(1),
+ MaxDelay = TimeSpan.FromSeconds(30),
+ };
+
+ var delays = Enumerable
+ .Range(0, 20)
+ .Select(_ => LambdaRetryHelper.ComputeDelay(2, options))
+ .ToList();
+
+ delays.Distinct().Count().Should().BeGreaterThan(1);
+ }
+
+ #endregion
+
+ #region IsTransient
+
+ [Test]
+ public void IsTransient_429_ReturnsTrue()
+ {
+ var ex = new AmazonServiceException("throttle")
+ {
+ StatusCode = HttpStatusCode.TooManyRequests,
+ };
+ LambdaRetryHelper.IsTransient(ex).Should().BeTrue();
+ }
+
+ [Test]
+ public void IsTransient_502_ReturnsTrue()
+ {
+ var ex = new AmazonServiceException("bad gateway")
+ {
+ StatusCode = HttpStatusCode.BadGateway,
+ };
+ LambdaRetryHelper.IsTransient(ex).Should().BeTrue();
+ }
+
+ [Test]
+ public void IsTransient_503_ReturnsTrue()
+ {
+ var ex = new AmazonServiceException("unavailable")
+ {
+ StatusCode = HttpStatusCode.ServiceUnavailable,
+ };
+ LambdaRetryHelper.IsTransient(ex).Should().BeTrue();
+ }
+
+ [Test]
+ public void IsTransient_504_ReturnsTrue()
+ {
+ var ex = new AmazonServiceException("timeout")
+ {
+ StatusCode = HttpStatusCode.GatewayTimeout,
+ };
+ LambdaRetryHelper.IsTransient(ex).Should().BeTrue();
+ }
+
+ [Test]
+ public void IsTransient_HttpRequestException_ReturnsTrue()
+ {
+ var ex = new HttpRequestException("connection reset");
+ LambdaRetryHelper.IsTransient(ex).Should().BeTrue();
+ }
+
+ [Test]
+ public void IsTransient_500_ReturnsFalse()
+ {
+ var ex = new AmazonServiceException("internal error")
+ {
+ StatusCode = HttpStatusCode.InternalServerError,
+ };
+ LambdaRetryHelper.IsTransient(ex).Should().BeFalse();
+ }
+
+ [Test]
+ public void IsTransient_ResourceNotFoundException_ReturnsFalse()
+ {
+ var ex = new ResourceNotFoundException("not found");
+ LambdaRetryHelper.IsTransient(ex).Should().BeFalse();
+ }
+
+ [Test]
+ public void IsTransient_InvalidOperationException_ReturnsFalse()
+ {
+ var ex = new InvalidOperationException("bad state");
+ LambdaRetryHelper.IsTransient(ex).Should().BeFalse();
+ }
+
+ #endregion
+
+ #region Cancellation
+
+ [Test]
+ public async Task InvokeWithRetryAsync_CancelledDuringRetry_ThrowsOperationCancelled()
+ {
+ var client = new SequentialMockLambdaClient([
+ Throw(HttpStatusCode.TooManyRequests),
+ Throw(HttpStatusCode.TooManyRequests),
+ Succeed(),
+ ]);
+
+ var options = new LambdaRetryOptions
+ {
+ MaxRetries = 5,
+ BaseDelay = TimeSpan.FromSeconds(10),
+ MaxDelay = TimeSpan.FromSeconds(30),
+ };
+
+ using var cts = new CancellationTokenSource();
+ cts.CancelAfter(TimeSpan.FromMilliseconds(50));
+
+ var act = async () =>
+ await LambdaRetryHelper.InvokeWithRetryAsync(
+ client,
+ DefaultRequest,
+ options,
+ NullLogger.Instance,
+ cts.Token
+ );
+
+ await act.Should().ThrowAsync();
+ }
+
+ #endregion
+
+ #region Helpers
+
+ private static Func Succeed() => () => new InvokeResponse { StatusCode = 200 };
+
+ private static Func Throw(HttpStatusCode statusCode) =>
+ () =>
+ throw new AmazonServiceException($"AWS error {statusCode}") { StatusCode = statusCode };
+
+ private static Func ThrowHttp(string message) =>
+ () => throw new HttpRequestException(message);
+
+ private static Func ThrowCustom(Exception ex) => () => throw ex;
+
+ #endregion
+
+ #region SequentialMockLambdaClient
+
+ private class SequentialMockLambdaClient(List> behaviors) : IAmazonLambda
+ {
+ private int _callIndex;
+ public int InvokeCount => _callIndex;
+
+ public IClientConfig Config => throw new NotImplementedException();
+ public ILambdaPaginatorFactory Paginators => throw new NotImplementedException();
+
+ public Task InvokeAsync(
+ InvokeRequest request,
+ CancellationToken cancellationToken = default
+ )
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ var index = _callIndex < behaviors.Count ? _callIndex : behaviors.Count - 1;
+ _callIndex++;
+
+ var response = behaviors[index]();
+ return Task.FromResult(response);
+ }
+
+ // Minimal interface stubs
+ public Task InvokeAsync(
+ string functionName,
+ CancellationToken cancellationToken = default
+ ) => throw new NotImplementedException();
+
+ public Task AddLayerVersionPermissionAsync(
+ AddLayerVersionPermissionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task AddPermissionAsync(
+ AddPermissionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task CreateAliasAsync(
+ CreateAliasRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task CreateCodeSigningConfigAsync(
+ CreateCodeSigningConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task CreateEventSourceMappingAsync(
+ CreateEventSourceMappingRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task CreateFunctionAsync(
+ CreateFunctionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task CreateFunctionUrlConfigAsync(
+ CreateFunctionUrlConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteAliasAsync(
+ DeleteAliasRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteCodeSigningConfigAsync(
+ DeleteCodeSigningConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteEventSourceMappingAsync(
+ DeleteEventSourceMappingRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteFunctionAsync(
+ string fn,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteFunctionAsync(
+ DeleteFunctionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteFunctionCodeSigningConfigAsync(
+ DeleteFunctionCodeSigningConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteFunctionConcurrencyAsync(
+ DeleteFunctionConcurrencyRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteFunctionEventInvokeConfigAsync(
+ DeleteFunctionEventInvokeConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteFunctionUrlConfigAsync(
+ DeleteFunctionUrlConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteLayerVersionAsync(
+ DeleteLayerVersionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task DeleteProvisionedConcurrencyConfigAsync(
+ DeleteProvisionedConcurrencyConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Amazon.Runtime.Endpoints.Endpoint DetermineServiceOperationEndpoint(
+ AmazonWebServiceRequest r
+ ) => throw new NotImplementedException();
+
+ public Task GetAccountSettingsAsync(
+ GetAccountSettingsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetAliasAsync(
+ GetAliasRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetCodeSigningConfigAsync(
+ GetCodeSigningConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetEventSourceMappingAsync(
+ GetEventSourceMappingRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionAsync(
+ string fn,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionAsync(
+ GetFunctionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionCodeSigningConfigAsync(
+ GetFunctionCodeSigningConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionConcurrencyAsync(
+ GetFunctionConcurrencyRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionConfigurationAsync(
+ string fn,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionConfigurationAsync(
+ GetFunctionConfigurationRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionEventInvokeConfigAsync(
+ GetFunctionEventInvokeConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionRecursionConfigAsync(
+ GetFunctionRecursionConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetFunctionUrlConfigAsync(
+ GetFunctionUrlConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetLayerVersionAsync(
+ GetLayerVersionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetLayerVersionByArnAsync(
+ GetLayerVersionByArnRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetLayerVersionPolicyAsync(
+ GetLayerVersionPolicyRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetPolicyAsync(string fn, CancellationToken ct = default) =>
+ throw new NotImplementedException();
+
+ public Task GetPolicyAsync(
+ GetPolicyRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetProvisionedConcurrencyConfigAsync(
+ GetProvisionedConcurrencyConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task GetRuntimeManagementConfigAsync(
+ GetRuntimeManagementConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task InvokeWithResponseStreamAsync(
+ InvokeWithResponseStreamRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListAliasesAsync(
+ ListAliasesRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListCodeSigningConfigsAsync(
+ ListCodeSigningConfigsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListEventSourceMappingsAsync(
+ ListEventSourceMappingsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListFunctionEventInvokeConfigsAsync(
+ ListFunctionEventInvokeConfigsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListFunctionsAsync(CancellationToken ct = default) =>
+ throw new NotImplementedException();
+
+ public Task ListFunctionsAsync(
+ ListFunctionsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListFunctionsByCodeSigningConfigAsync(
+ ListFunctionsByCodeSigningConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListFunctionUrlConfigsAsync(
+ ListFunctionUrlConfigsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListLayersAsync(
+ ListLayersRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListLayerVersionsAsync(
+ ListLayerVersionsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListProvisionedConcurrencyConfigsAsync(
+ ListProvisionedConcurrencyConfigsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListTagsAsync(
+ ListTagsRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task ListVersionsByFunctionAsync(
+ ListVersionsByFunctionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PublishLayerVersionAsync(
+ PublishLayerVersionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PublishVersionAsync(
+ PublishVersionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PutFunctionCodeSigningConfigAsync(
+ PutFunctionCodeSigningConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PutFunctionConcurrencyAsync(
+ PutFunctionConcurrencyRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PutFunctionEventInvokeConfigAsync(
+ PutFunctionEventInvokeConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PutFunctionRecursionConfigAsync(
+ PutFunctionRecursionConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PutProvisionedConcurrencyConfigAsync(
+ PutProvisionedConcurrencyConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task PutRuntimeManagementConfigAsync(
+ PutRuntimeManagementConfigRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task RemoveLayerVersionPermissionAsync(
+ RemoveLayerVersionPermissionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task RemovePermissionAsync(
+ string fn,
+ string sid,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task RemovePermissionAsync(
+ RemovePermissionRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task TagResourceAsync(
+ TagResourceRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ public Task UntagResourceAsync(
+ UntagResourceRequest r,
+ CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateAliasResponse> UpdateAliasAsync(
+     UpdateAliasRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateCodeSigningConfigResponse> UpdateCodeSigningConfigAsync(
+     UpdateCodeSigningConfigRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateEventSourceMappingResponse> UpdateEventSourceMappingAsync(
+     UpdateEventSourceMappingRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateFunctionCodeResponse> UpdateFunctionCodeAsync(
+     UpdateFunctionCodeRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateFunctionConfigurationResponse> UpdateFunctionConfigurationAsync(
+     UpdateFunctionConfigurationRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateFunctionEventInvokeConfigResponse> UpdateFunctionEventInvokeConfigAsync(
+     UpdateFunctionEventInvokeConfigRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateFunctionUrlConfigResponse> UpdateFunctionUrlConfigAsync(
+     UpdateFunctionUrlConfigRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<CheckpointDurableExecutionResponse> CheckpointDurableExecutionAsync(
+     CheckpointDurableExecutionRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<CreateCapacityProviderResponse> CreateCapacityProviderAsync(
+     CreateCapacityProviderRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<DeleteCapacityProviderResponse> DeleteCapacityProviderAsync(
+     DeleteCapacityProviderRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<GetCapacityProviderResponse> GetCapacityProviderAsync(
+     GetCapacityProviderRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<GetDurableExecutionResponse> GetDurableExecutionAsync(
+     GetDurableExecutionRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<GetDurableExecutionHistoryResponse> GetDurableExecutionHistoryAsync(
+     GetDurableExecutionHistoryRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<GetDurableExecutionStateResponse> GetDurableExecutionStateAsync(
+     GetDurableExecutionStateRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<GetFunctionScalingConfigResponse> GetFunctionScalingConfigAsync(
+     GetFunctionScalingConfigRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<ListCapacityProvidersResponse> ListCapacityProvidersAsync(
+     ListCapacityProvidersRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<ListDurableExecutionsByFunctionResponse> ListDurableExecutionsByFunctionAsync(
+     ListDurableExecutionsByFunctionRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<ListFunctionVersionsByCapacityProviderResponse> ListFunctionVersionsByCapacityProviderAsync(
+     ListFunctionVersionsByCapacityProviderRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<PutFunctionScalingConfigResponse> PutFunctionScalingConfigAsync(
+     PutFunctionScalingConfigRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<SendDurableExecutionCallbackFailureResponse> SendDurableExecutionCallbackFailureAsync(
+     SendDurableExecutionCallbackFailureRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<SendDurableExecutionCallbackHeartbeatResponse> SendDurableExecutionCallbackHeartbeatAsync(
+     SendDurableExecutionCallbackHeartbeatRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<SendDurableExecutionCallbackSuccessResponse> SendDurableExecutionCallbackSuccessAsync(
+     SendDurableExecutionCallbackSuccessRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<StopDurableExecutionResponse> StopDurableExecutionAsync(
+     StopDurableExecutionRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Unimplemented stub: always throws <see cref="NotImplementedException"/> when called.</summary>
+ public Task<UpdateCapacityProviderResponse> UpdateCapacityProviderAsync(
+     UpdateCapacityProviderRequest r,
+     CancellationToken ct = default
+ ) => throw new NotImplementedException();
+
+ /// <summary>Intentionally empty: disposing performs no work.</summary>
+ public void Dispose()
+ {
+ }
+ }
+
+ #endregion
+}
diff --git a/tests/Trax.Scheduler.Tests/UnitTests/LambdaRunExecutorTests.cs b/tests/Trax.Scheduler.Tests/UnitTests/LambdaRunExecutorTests.cs
index 43cb4b7..85eda40 100644
--- a/tests/Trax.Scheduler.Tests/UnitTests/LambdaRunExecutorTests.cs
+++ b/tests/Trax.Scheduler.Tests/UnitTests/LambdaRunExecutorTests.cs
@@ -1,7 +1,9 @@
+using System.Net;
using System.Text;
using System.Text.Json;
using Amazon.Lambda;
using Amazon.Lambda.Model;
+using Amazon.Runtime;
using FluentAssertions;
using Microsoft.Extensions.Logging.Abstractions;
using Trax.Core.Exceptions;
@@ -288,6 +290,49 @@ await executor.ExecuteAsync(
#endregion
+ #region Retry
+
+ /// <summary>
+ /// Queues a single HTTP 429 service exception ahead of a successful response and checks that
+ /// the executor still returns the successful result after exactly two invocation attempts.
+ /// </summary>
+ [Test]
+ public async Task ExecuteAsync_ThrottleThenSuccess_RetriesTransparentlyAndReturnsResult()
+ {
+     // Arrange
+     var response = new RemoteRunResponse(
+         MetadataId: 42,
+         OutputJson: """{"value":"ok","count":1}""",
+         OutputType: typeof(TestRunOutput).FullName
+     );
+     var client = CreateMockClient(response);
+     // The mock throws each queued exception once before it starts returning the response.
+     client.ExceptionsBeforeSuccess.Enqueue(
+         new AmazonServiceException("Throttled") { StatusCode = HttpStatusCode.TooManyRequests }
+     );
+     var executor = new LambdaRunExecutor(
+         client,
+         new LambdaRunOptions
+         {
+             FunctionName = "my-runner",
+             Retry = new LambdaRetryOptions
+             {
+                 MaxRetries = 3,
+                 // Tiny base delay so the test does not spend time waiting on back-off.
+                 BaseDelay = TimeSpan.FromMilliseconds(1),
+             },
+         },
+         // Closed generic logger: assignable to both ILogger and ILogger<LambdaRunExecutor>.
+         NullLogger<LambdaRunExecutor>.Instance
+     );
+
+     // Act
+     var result = await executor.ExecuteAsync(
+         "My.Train",
+         new TestRunInput { Name = "test" },
+         typeof(TestRunOutput)
+     );
+
+     // Assert
+     result.MetadataId.Should().Be(42);
+     client.AllRequests.Should().HaveCount(2); // 1 throttled + 1 success
+ }
+
+ #endregion
+
#region Helpers
private static LambdaRunExecutor CreateExecutor(MockLambdaClient client) =>
@@ -328,8 +373,10 @@ public record TestRunOutput
private class MockLambdaClient : IAmazonLambda
{
public InvokeRequest? LastRequest { get; private set; }
+ /// <summary>Every request handed to <c>InvokeAsync</c>, in call order; recorded before any queued exception is thrown.</summary>
+ public List<InvokeRequest> AllRequests { get; } = [];
public string? FunctionError { get; set; }
public MemoryStream? ResponsePayload { get; set; }
+ /// <summary>Exceptions to throw from <c>InvokeAsync</c>, one per call in FIFO order, while the queue is non-empty.</summary>
+ public Queue<Exception> ExceptionsBeforeSuccess { get; } = new();
public Amazon.Runtime.IClientConfig Config => throw new NotImplementedException();
@@ -342,6 +389,11 @@ public Task InvokeAsync(
{
cancellationToken.ThrowIfCancellationRequested();
+ AllRequests.Add(request);
+
+ if (ExceptionsBeforeSuccess.Count > 0)
+ throw ExceptionsBeforeSuccess.Dequeue();
+
LastRequest = request;
var payload = ResponsePayload ?? new MemoryStream(Encoding.UTF8.GetBytes("{}"));