-
Notifications
You must be signed in to change notification settings - Fork 52
[BED-6389] Add retry logic to GetDomains to make more resilient to network hiccups #242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v4
Are you sure you want to change the base?
Changes from all commits
3d1d6a8
ed71778
7405405
795f1f3
e606ac4
9d17eb9
c43c449
257fe74
8d91bb6
1b32498
a4d369f
e9bea59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -30,6 +30,9 @@ public class LdapUtils : ILdapUtils { | |||||||||||||||||
private static ConcurrentDictionary<string, Domain> _domainCache = new(); | ||||||||||||||||||
private static ConcurrentHashSet _domainControllers = new(StringComparer.OrdinalIgnoreCase); | ||||||||||||||||||
private static ConcurrentHashSet _unresolvablePrincipals = new(StringComparer.OrdinalIgnoreCase); | ||||||||||||||||||
|
||||||||||||||||||
// Tracks Domains we know we've determined we shouldn't try to connect to | ||||||||||||||||||
private static ConcurrentHashSet _excludedDomains = new(StringComparer.OrdinalIgnoreCase); | ||||||||||||||||||
|
||||||||||||||||||
private static readonly ConcurrentDictionary<string, string> DomainToForestCache = | ||||||||||||||||||
new(StringComparer.OrdinalIgnoreCase); | ||||||||||||||||||
|
@@ -50,7 +53,7 @@ private readonly ConcurrentDictionary<string, string> | |||||||||||||||||
private readonly ILogger _log; | ||||||||||||||||||
private readonly IPortScanner _portScanner; | ||||||||||||||||||
private readonly NativeMethods _nativeMethods; | ||||||||||||||||||
private readonly string _nullCacheKey = Guid.NewGuid().ToString(); | ||||||||||||||||||
private static readonly string _nullCacheKey = Guid.NewGuid().ToString(); | ||||||||||||||||||
private static readonly Regex SIDRegex = new(@"^(S-\d+-\d+-\d+-\d+-\d+-\d+)(-\d+)?$"); | ||||||||||||||||||
|
||||||||||||||||||
private readonly string[] _translateNames = { "Administrator", "admin" }; | ||||||||||||||||||
|
@@ -506,20 +509,28 @@ public bool GetDomain(string domainName, out Domain domain) { | |||||||||||||||||
: new DirectoryContext(DirectoryContextType.Domain); | ||||||||||||||||||
|
||||||||||||||||||
// Blocking External Call | ||||||||||||||||||
domain = Domain.GetDomain(context); | ||||||||||||||||||
domain = Helpers.RetryOnException<ActiveDirectoryObjectNotFoundException, Domain>(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult(); | ||||||||||||||||||
if (domain == null) return false; | ||||||||||||||||||
_domainCache.TryAdd(cacheKey, domain); | ||||||||||||||||||
return true; | ||||||||||||||||||
} | ||||||||||||||||||
catch (Exception e) { | ||||||||||||||||||
// The Static GetDomain Function ran into an issue requiring to exclude a domain as it would continuously | ||||||||||||||||||
// try to connect to a domain that it could not connect to. This method may also need the same logic. | ||||||||||||||||||
_log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName); | ||||||||||||||||||
domain = null; | ||||||||||||||||||
return false; | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domain domain) { | ||||||||||||||||||
var cacheKey = domainName ?? _nullCacheKey; | ||||||||||||||||||
if (_domainCache.TryGetValue(domainName, out domain)) return true; | ||||||||||||||||||
if (IsExcludedDomain(domainName)) { | ||||||||||||||||||
Logging.Logger.LogDebug("Domain: {DomainName} has been excluded for collection. Skipping", domainName); | ||||||||||||||||||
domain = null; | ||||||||||||||||||
return false; | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
try { | ||||||||||||||||||
DirectoryContext context; | ||||||||||||||||||
|
@@ -535,14 +546,17 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai | |||||||||||||||||
: new DirectoryContext(DirectoryContextType.Domain); | ||||||||||||||||||
|
||||||||||||||||||
// Blocking External Call | ||||||||||||||||||
domain = Domain.GetDomain(context); | ||||||||||||||||||
domain = Helpers.RetryOnException<ActiveDirectoryObjectNotFoundException, Domain>(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult(); | ||||||||||||||||||
if (domain == null) return false; | ||||||||||||||||||
_domainCache.TryAdd(domainName, domain); | ||||||||||||||||||
_domainCache.TryAdd(cacheKey, domain); | ||||||||||||||||||
return true; | ||||||||||||||||||
} | ||||||||||||||||||
catch (Exception e) { | ||||||||||||||||||
Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName, | ||||||||||||||||||
Logging.Logger.LogDebug("Static GetDomain call failed, adding to exclusion, for domain {DomainName}: {Error}", domainName, | ||||||||||||||||||
e.Message); | ||||||||||||||||||
// If a domain cannot be contacted, this will exclude the domain so that it does not continuously try to connect, and | ||||||||||||||||||
// cause more timeouts. | ||||||||||||||||||
AddExcludedDomain(cacheKey); | ||||||||||||||||||
domain = null; | ||||||||||||||||||
return false; | ||||||||||||||||||
} | ||||||||||||||||||
|
@@ -565,11 +579,13 @@ public bool GetDomain(out Domain domain) { | |||||||||||||||||
: new DirectoryContext(DirectoryContextType.Domain); | ||||||||||||||||||
|
||||||||||||||||||
// Blocking External Call | ||||||||||||||||||
domain = Domain.GetDomain(context); | ||||||||||||||||||
domain = Helpers.RetryOnException<ActiveDirectoryObjectNotFoundException, Domain>(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult(); | ||||||||||||||||||
_domainCache.TryAdd(_nullCacheKey, domain); | ||||||||||||||||||
return true; | ||||||||||||||||||
} | ||||||||||||||||||
catch (Exception e) { | ||||||||||||||||||
// The Static GetDomain Function ran into an issue requiring to exclude a domain as it would continuously | ||||||||||||||||||
// try to connect to a domain that it could not connect to. This method may also need the same logic. | ||||||||||||||||||
_log.LogDebug(e, "GetDomain call failed for blank domain"); | ||||||||||||||||||
domain = null; | ||||||||||||||||||
return false; | ||||||||||||||||||
|
@@ -1129,6 +1145,7 @@ public void ResetUtils() { | |||||||||||||||||
_domainControllers = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase); | ||||||||||||||||||
_connectionPool?.Dispose(); | ||||||||||||||||||
_connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); | ||||||||||||||||||
_excludedDomains = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase); | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
private IDirectoryObject CreateDirectoryEntry(string path) { | ||||||||||||||||||
|
@@ -1143,6 +1160,9 @@ public void Dispose() { | |||||||||||||||||
_connectionPool?.Dispose(); | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
public static bool IsExcludedDomain(string domain) => _excludedDomains.Contains(domain); | ||||||||||||||||||
public static void AddExcludedDomain(string domain) => _excludedDomains.Add(domain); | ||||||||||||||||||
|
||||||||||||||||||
Comment on lines
+1163
to
+1165
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard exclusion helpers against null/empty inputs. Current code will throw ArgumentNullException on null domain. Also avoid polluting the set with empty strings. Apply this diff: -public static bool IsExcludedDomain(string domain) => _excludedDomains.Contains(domain);
-public static void AddExcludedDomain(string domain) => _excludedDomains.Add(domain);
+public static bool IsExcludedDomain(string domain) =>
+ !string.IsNullOrWhiteSpace(domain) && _excludedDomains.Contains(domain);
+public static void AddExcludedDomain(string domain) {
+ if (!string.IsNullOrWhiteSpace(domain))
+ _excludedDomains.Add(domain);
+} 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||
internal static bool ResolveLabel(string objectIdentifier, string distinguishedName, string samAccountType, | ||||||||||||||||||
string[] objectClasses, int flags, out Label type) { | ||||||||||||||||||
type = Label.Base; | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,20 @@ | ||||||||||||||||
using System; | ||||||||||||||||
using System.Threading; | ||||||||||||||||
|
||||||||||||||||
namespace SharpHoundCommonLib; | ||||||||||||||||
|
||||||||||||||||
public static class RandomUtils { | ||||||||||||||||
private static readonly ThreadLocal<Random> Random = new(() => new Random()); | ||||||||||||||||
|
||||||||||||||||
public static long NextLong() => LongRandom(0, long.MaxValue); | ||||||||||||||||
|
||||||||||||||||
private static long LongRandom(long min, long max) { | ||||||||||||||||
var buf = new byte[8]; | ||||||||||||||||
Random.Value.NextBytes(buf); | ||||||||||||||||
var longRand = BitConverter.ToInt64(buf, 0); | ||||||||||||||||
return Math.Abs(longRand % (max - min)) + min; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
public static double Between(double minValue, double maxValue) => Random.Value.NextDouble() * (maxValue - minValue) + minValue; | ||||||||||||||||
public static long Between(long minValue, long maxValue) => LongRandom(minValue, maxValue); | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Between(long): enforce valid range and remove bias. Current implementation inherits LongRandom’s overflow/bias issues. Apply this diff: - public static long Between(long minValue, long maxValue) => LongRandom(minValue, maxValue);
+ public static long Between(long minValue, long maxValue)
+ {
+ if (maxValue <= minValue)
+ throw new ArgumentOutOfRangeException(nameof(maxValue), "maxValue must be greater than minValue.");
+ return RandomNumberGenerator.GetInt64(minValue, maxValue); // [min, max)
+ } Also consider documenting the inclusive/exclusive semantics ([min, max)) for callers. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don’t blanket‑exclude the domain on any exception at the end of CreateNewConnection.
Catching Exception and permanently excluding can mask auth/config issues and reduce resiliency. Restrict exclusion to specific transient/network failures or remove it here and rely on upstream retry/strategy logic.
Apply this change:
📝 Committable suggestion
🤖 Prompt for AI Agents