Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion src/CommonLib/Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
using SharpHoundCommonLib.Processors;
using Microsoft.Win32;
using System.Threading.Tasks;
using System.Threading;

namespace SharpHoundCommonLib {
public static class Helpers {
private static readonly HashSet<string> Groups = new() { "268435456", "268435457", "536870912", "536870913" };
private static readonly HashSet<string> Computers = new() { "805306369" };
private static readonly HashSet<string> Users = new() { "805306368", "805306370" };
private static readonly double MaxTimeSpanTicks = (double)TimeSpan.MaxValue.Ticks - 1_000;

private static readonly Regex DCReplaceRegex = new("DC=", RegexOptions.IgnoreCase | RegexOptions.Compiled);
private static readonly Regex SPNRegex = new(@".*\/.*", RegexOptions.Compiled);
Expand Down Expand Up @@ -318,15 +320,28 @@ public static string DumpDirectoryObject(this IDirectoryObject directoryObject)
return builder.ToString();
}

public static TimeSpan BackoffWithDecorrelatedJitter(int attempt, TimeSpan baseDelay, TimeSpan maxDelay) {
// Decorrelated Jitter Backoff - see https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
var temp = Math.Min(maxDelay.Ticks, baseDelay.Ticks * (attempt * attempt));
temp = temp / 2 + RandomUtils.Between(0, temp / 2);
var ticksToDelay = Math.Min(maxDelay.Ticks, RandomUtils.Between(baseDelay.Ticks, temp * 3));

// This ensures that a TimeSpan can be created with the ticks amount as TimeSpan uses a long.
return double.IsInfinity(ticksToDelay) ? TimeSpan.FromTicks((long)MaxTimeSpanTicks) :
TimeSpan.FromTicks((long)Math.Min(MaxTimeSpanTicks, ticksToDelay));
}

/// <summary>
/// Attempt an action a number of times, quietly eating a specific exception until the last attempt if it throws.
/// </summary>
/// <param name="action"></param>
/// <param name="retryCount"></param>
/// <param name="logger"></param>
public static async Task RetryOnException<T>(Func<Task> action, int retryCount, ILogger logger = null) where T : Exception {
public static async Task RetryOnException<T>(Func<Task> action, int retryCount, TimeSpan? baseDelay = null, TimeSpan? maxDelay = null, ILogger logger = null) where T : Exception {
int attempt = 0;
bool success = false;
baseDelay ??= TimeSpan.FromSeconds(1);
maxDelay ??= TimeSpan.FromSeconds(30);
do {
try {
await action();
Expand All @@ -337,9 +352,33 @@ public static async Task RetryOnException<T>(Func<Task> action, int retryCount,
logger?.LogDebug(e, "Exception caught, retrying attempt {Attempt}", attempt);
if (attempt >= retryCount)
throw;

var delay = BackoffWithDecorrelatedJitter(attempt, baseDelay.Value, maxDelay.Value);
await Task.Delay(delay);
}
} while (!success && attempt < retryCount);
}

public static async Task<U> RetryOnException<T, U>(Func<U> action, int retryCount, TimeSpan? baseDelay = null, TimeSpan? maxDelay = null, ILogger logger = null) where T : Exception {
int attempt = 0;
baseDelay ??= TimeSpan.FromSeconds(1);
maxDelay ??= TimeSpan.FromSeconds(30);
do {
try {
return action();
}
catch (T e) {
attempt++;
logger?.LogDebug(e, "Exception caught, retrying attempt {Attempt}", attempt);
if (attempt >= retryCount)
throw;
var delay = BackoffWithDecorrelatedJitter(attempt, baseDelay.Value, maxDelay.Value);
await Task.Delay(delay);
}
} while (attempt < retryCount);

throw new InvalidOperationException($"You really shouldn't be here, {nameof(RetryOnException)} isn't working as intended.");
}
}

public class ParsedGPLink {
Expand Down
11 changes: 4 additions & 7 deletions src/CommonLib/LdapConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ internal class LdapConnectionPool : IDisposable {
private const int MaxRetries = 3;
private static readonly ConcurrentDictionary<string, NetAPIStructs.DomainControllerInfo?> DCInfoCache = new();

// Tracks domains we know we've determined we shouldn't try to connect to
private static readonly ConcurrentHashSet _excludedDomains = new();

public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config,
IPortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) {
_connections = new ConcurrentBag<LdapConnectionWrapper>();
Expand Down Expand Up @@ -693,7 +690,7 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)>
GetConnectionAsync() {
if (_excludedDomains.Contains(_identifier)) {
if (LdapUtils.IsExcludedDomain(_identifier)) {
return (false, null, $"Identifier {_identifier} excluded for connection attempt");
}

Expand Down Expand Up @@ -727,7 +724,7 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll

public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)>
GetGlobalCatalogConnectionAsync() {
if (_excludedDomains.Contains(_identifier)) {
if (LdapUtils.IsExcludedDomain(_identifier)) {
return (false, null, $"Identifier {_identifier} excluded for connection attempt");
}

Expand Down Expand Up @@ -813,7 +810,7 @@ await CreateLdapConnection(tempDomainName, globalCatalog) is (true, var connecti
_log.LogDebug(
"Could not get domain object from GetDomain, unable to create ldap connection for domain {Domain}",
_identifier);
_excludedDomains.Add(_identifier);
LdapUtils.AddExcludedDomain(_identifier);
return (false, null, "Unable to get domain object for further strategies");
}

Expand Down Expand Up @@ -852,7 +849,7 @@ await CreateLdapConnection(tempDomainName, globalCatalog) is (true, var connecti
catch (Exception e) {
_log.LogInformation(e, "We will not be able to connect to domain {Domain} by any strategy, leaving it.",
_identifier);
_excludedDomains.Add(_identifier);
LdapUtils.AddExcludedDomain(_identifier);
}
Comment on lines +852 to 853
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Don’t blanket‑exclude the domain on any exception at the end of CreateNewConnection.

Catching Exception and permanently excluding can mask auth/config issues and reduce resiliency. Restrict exclusion to specific transient/network failures or remove it here and rely on upstream retry/strategy logic.

Apply this change:

-                LdapUtils.AddExcludedDomain(_identifier);
+                // Avoid permanent exclusion on generic exceptions; rely on retries/strategies.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
LdapUtils.AddExcludedDomain(_identifier);
}
// Avoid permanent exclusion on generic exceptions; rely on retries/strategies.
}
🤖 Prompt for AI Agents
In src/CommonLib/LdapConnectionPool.cs around lines 852-853, the code currently
catches all Exception and calls LdapUtils.AddExcludedDomain(_identifier), which
can hide auth/configuration errors and permanently blacklist domains; change
this to not blanket-exclude on any exception by replacing the broad catch with
targeted handling: catch only transient/network-related exceptions (e.g.,
LdapException with transient/error codes, SocketException, TimeoutException or
whatever concrete exceptions your LDAP client throws) and call
LdapUtils.AddExcludedDomain(_identifier) only in those catches; for other
exceptions rethrow after logging (or log and let upstream retry/strategy logic
handle it) and avoid swallowing Exception; ensure logging includes the exception
details and update any unit/integration tests to reflect the new behavior.


return (false, null, "All attempted connections failed");
Expand Down
32 changes: 26 additions & 6 deletions src/CommonLib/LdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ public class LdapUtils : ILdapUtils {
private static ConcurrentDictionary<string, Domain> _domainCache = new();
private static ConcurrentHashSet _domainControllers = new(StringComparer.OrdinalIgnoreCase);
private static ConcurrentHashSet _unresolvablePrincipals = new(StringComparer.OrdinalIgnoreCase);

// Tracks Domains we know we've determined we shouldn't try to connect to
private static ConcurrentHashSet _excludedDomains = new(StringComparer.OrdinalIgnoreCase);

private static readonly ConcurrentDictionary<string, string> DomainToForestCache =
new(StringComparer.OrdinalIgnoreCase);
Expand All @@ -50,7 +53,7 @@ private readonly ConcurrentDictionary<string, string>
private readonly ILogger _log;
private readonly IPortScanner _portScanner;
private readonly NativeMethods _nativeMethods;
private readonly string _nullCacheKey = Guid.NewGuid().ToString();
private static readonly string _nullCacheKey = Guid.NewGuid().ToString();
private static readonly Regex SIDRegex = new(@"^(S-\d+-\d+-\d+-\d+-\d+-\d+)(-\d+)?$");

private readonly string[] _translateNames = { "Administrator", "admin" };
Expand Down Expand Up @@ -506,20 +509,28 @@ public bool GetDomain(string domainName, out Domain domain) {
: new DirectoryContext(DirectoryContextType.Domain);

// Blocking External Call
domain = Domain.GetDomain(context);
domain = Helpers.RetryOnException<ActiveDirectoryObjectNotFoundException, Domain>(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult();
if (domain == null) return false;
_domainCache.TryAdd(cacheKey, domain);
return true;
}
catch (Exception e) {
// The Static GetDomain Function ran into an issue requiring to exclude a domain as it would continuously
// try to connect to a domain that it could not connect to. This method may also need the same logic.
_log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName);
domain = null;
return false;
}
}

public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domain domain) {
var cacheKey = domainName ?? _nullCacheKey;
if (_domainCache.TryGetValue(domainName, out domain)) return true;
if (IsExcludedDomain(domainName)) {
Logging.Logger.LogDebug("Domain: {DomainName} has been excluded for collection. Skipping", domainName);
domain = null;
return false;
}

try {
DirectoryContext context;
Expand All @@ -535,14 +546,17 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai
: new DirectoryContext(DirectoryContextType.Domain);

// Blocking External Call
domain = Domain.GetDomain(context);
domain = Helpers.RetryOnException<ActiveDirectoryObjectNotFoundException, Domain>(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult();
if (domain == null) return false;
_domainCache.TryAdd(domainName, domain);
_domainCache.TryAdd(cacheKey, domain);
return true;
}
catch (Exception e) {
Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName,
Logging.Logger.LogDebug("Static GetDomain call failed, adding to exclusion, for domain {DomainName}: {Error}", domainName,
e.Message);
// If a domain cannot be contacted, this will exclude the domain so that it does not continuously try to connect, and
// cause more timeouts.
AddExcludedDomain(cacheKey);
domain = null;
return false;
}
Expand All @@ -565,11 +579,13 @@ public bool GetDomain(out Domain domain) {
: new DirectoryContext(DirectoryContextType.Domain);

// Blocking External Call
domain = Domain.GetDomain(context);
domain = Helpers.RetryOnException<ActiveDirectoryObjectNotFoundException, Domain>(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult();
_domainCache.TryAdd(_nullCacheKey, domain);
return true;
}
catch (Exception e) {
// The Static GetDomain Function ran into an issue requiring to exclude a domain as it would continuously
// try to connect to a domain that it could not connect to. This method may also need the same logic.
_log.LogDebug(e, "GetDomain call failed for blank domain");
domain = null;
return false;
Expand Down Expand Up @@ -1129,6 +1145,7 @@ public void ResetUtils() {
_domainControllers = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase);
_connectionPool?.Dispose();
_connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner);
_excludedDomains = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase);
}

private IDirectoryObject CreateDirectoryEntry(string path) {
Expand All @@ -1143,6 +1160,9 @@ public void Dispose() {
_connectionPool?.Dispose();
}

public static bool IsExcludedDomain(string domain) => _excludedDomains.Contains(domain);
public static void AddExcludedDomain(string domain) => _excludedDomains.Add(domain);

Comment on lines +1163 to +1165
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Guard exclusion helpers against null/empty inputs.

Current code will throw ArgumentNullException on null domain. Also avoid polluting the set with empty strings.

Apply this diff:

-public static bool IsExcludedDomain(string domain) => _excludedDomains.Contains(domain);
-public static void AddExcludedDomain(string domain) => _excludedDomains.Add(domain);
+public static bool IsExcludedDomain(string domain) =>
+    !string.IsNullOrWhiteSpace(domain) && _excludedDomains.Contains(domain);
+public static void AddExcludedDomain(string domain) {
+    if (!string.IsNullOrWhiteSpace(domain))
+        _excludedDomains.Add(domain);
+}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
public static bool IsExcludedDomain(string domain) => _excludedDomains.Contains(domain);
public static void AddExcludedDomain(string domain) => _excludedDomains.Add(domain);
public static bool IsExcludedDomain(string domain) => !string.IsNullOrWhiteSpace(domain) && _excludedDomains.Contains(domain);
public static void AddExcludedDomain(string domain)
{
if (!string.IsNullOrWhiteSpace(domain))
_excludedDomains.Add(domain);
}
🤖 Prompt for AI Agents
In src/CommonLib/LdapUtils.cs around lines 1161 to 1163, the helpers don’t guard
against null or empty domain values which causes ArgumentNullException and
allows empty strings into the set; update IsExcludedDomain to return false when
domain is null or empty (optionally trim) and update AddExcludedDomain to no-op
when domain is null or empty (optionally trim and normalize case before adding)
so the set is never polluted and calls are safe.

internal static bool ResolveLabel(string objectIdentifier, string distinguishedName, string samAccountType,
string[] objectClasses, int flags, out Label type) {
type = Label.Base;
Expand Down
20 changes: 20 additions & 0 deletions src/CommonLib/RandomUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System;
using System.Threading;

namespace SharpHoundCommonLib;

public static class RandomUtils {
private static readonly ThreadLocal<Random> Random = new(() => new Random());

public static long NextLong() => LongRandom(0, long.MaxValue);

private static long LongRandom(long min, long max) {
var buf = new byte[8];
Random.Value.NextBytes(buf);
var longRand = BitConverter.ToInt64(buf, 0);
return Math.Abs(longRand % (max - min)) + min;
}

public static double Between(double minValue, double maxValue) => Random.Value.NextDouble() * (maxValue - minValue) + minValue;
public static long Between(long minValue, long maxValue) => LongRandom(minValue, maxValue);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Between(long): enforce valid range and remove bias.

Current implementation inherits LongRandom’s overflow/bias issues.

Apply this diff:

-    public static long Between(long minValue, long maxValue) => LongRandom(minValue, maxValue);
+    public static long Between(long minValue, long maxValue)
+    {
+        if (maxValue <= minValue)
+            throw new ArgumentOutOfRangeException(nameof(maxValue), "maxValue must be greater than minValue.");
+        return RandomNumberGenerator.GetInt64(minValue, maxValue); // [min, max)
+    }

Also consider documenting the inclusive/exclusive semantics ([min, max)) for callers.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
public static long Between(long minValue, long maxValue) => LongRandom(minValue, maxValue);
public static long Between(long minValue, long maxValue)
{
if (maxValue <= minValue)
throw new ArgumentOutOfRangeException(nameof(maxValue), "maxValue must be greater than minValue.");
return RandomNumberGenerator.GetInt64(minValue, maxValue); // [min, max)
}
🤖 Prompt for AI Agents
In src/CommonLib/RandomUtils.cs around line 18, replace the current one-liner
with a safe implementation that validates the range and avoids bias/overflow:
throw an ArgumentOutOfRangeException if minValue >= maxValue, then use an
unbiased provider such as
System.Security.Cryptography.RandomNumberGenerator.GetInt64(minValue, maxValue)
(or Random.Shared.NextInt64 when targeting a runtime that provides it) to return
a value in [min, max); and add a brief comment documenting the
inclusive/exclusive semantics ([min, max)) for callers.

}
25 changes: 1 addition & 24 deletions test/unit/LdapConnectionPoolTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,16 @@

public class LdapConnectionPoolTest
{
private static void AddExclusionDomain(string identifier) {
var excludedDomainsField = typeof(LdapConnectionPool)
.GetField("_excludedDomains", BindingFlags.Static | BindingFlags.NonPublic);

var excludedDomains = (ConcurrentHashSet)excludedDomainsField.GetValue(null);

excludedDomains.Add(identifier);
}

[Fact]
public async Task LdapConnectionPool_ExcludedDomains_ShouldExitEarly()
public async Task LdapConnectionPool_Static_GetDomain_Add_To_ExcludedDomains_ShouldExitEarly()
{
var mockLogger = new Mock<ILogger>();
var ldapConfig = new LdapConfig();
var connectionPool = new ConnectionPoolManager(ldapConfig, mockLogger.Object);

AddExclusionDomain("excludedDomain.com");
var connectAttempt = await connectionPool.TestDomainConnection("excludedDomain.com", false);

Assert.False(connectAttempt.Success);
Assert.Contains("excluded for connection attempt", connectAttempt.Message);
}

[Fact]
public async Task LdapConnectionPool_ExcludedDomains_NonExcludedShouldntExit()
{
var mockLogger = new Mock<ILogger>();
var ldapConfig = new LdapConfig();
var connectionPool = new ConnectionPoolManager(ldapConfig, mockLogger.Object);

AddExclusionDomain("excludedDomain.com");
var connectAttempt = await connectionPool.TestDomainConnection("perfectlyValidDomain.com", false);

Assert.DoesNotContain("excluded for connection attempt", connectAttempt.Message);
}
}
17 changes: 17 additions & 0 deletions test/unit/TimeoutTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,21 @@ public async Task ExecuteWithTimeout_Task_T_ParentTokenCancel() {
Assert.False(result.IsSuccess);
Assert.Equal("Cancellation requested", result.Error);
}

[Theory]
[InlineData(0, 2, 30, 2, 6)]
[InlineData(5, 2, 200, 1, 192)]
[InlineData(5, 5, 500, 5, 480)]
[InlineData(0, 2, 1, 1, 1)]
[InlineData(5, 2, 1, 1, 1)]
[InlineData(5, 2, 2, 2, 2)]
[InlineData(5, 30, 30, 30, 30)]
public void DecorrelatedTimeSpan_BetweenExpected(int attempt, int baseDelayValue, int maxDelayValue, double expectedLowerBound, double expectedUpperBound) {
var baseDelay = TimeSpan.FromTicks(baseDelayValue);
var maxDelay = TimeSpan.FromTicks(maxDelayValue);
for (var trials = 0; trials < 500; trials++) {
var delay = SharpHoundCommonLib.Helpers.BackoffWithDecorrelatedJitter(attempt, baseDelay, maxDelay);
Assert.InRange(delay.Ticks, expectedLowerBound, expectedUpperBound);
}
}
}
Loading