diff --git a/.pylintrc b/.pylintrc index 780195824..8f3bb63eb 100644 --- a/.pylintrc +++ b/.pylintrc @@ -29,5 +29,4 @@ disable = no-value-for-parameter, unexpected-keyword-arg, inconsistent-return-statements, - duplicate-code, - attribute-defined-outside-init, + duplicate-code diff --git a/aikido_zen/background_process/cloud_connection_manager/__init__.py b/aikido_zen/background_process/cloud_connection_manager/__init__.py index 579fa0ef7..e4fd38f54 100644 --- a/aikido_zen/background_process/cloud_connection_manager/__init__.py +++ b/aikido_zen/background_process/cloud_connection_manager/__init__.py @@ -35,13 +35,7 @@ def __init__(self, block, api, token, serverless): self.token = token # Should be instance of the Token class! self.routes = Routes(200) self.hostnames = Hostnames(200) - self.conf = ServiceConfig( - endpoints=[], - last_updated_at=-1, # Has not been updated yet - blocked_uids=[], - bypassed_ips=[], - received_any_stats=True, - ) + self.conf = ServiceConfig() self.firewall_lists = FirewallLists() self.rate_limiter = RateLimiter( max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py index 7d7f729e8..fa0d855aa 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py @@ -15,10 +15,11 @@ def update_service_config(connection_manager, res): logger.debug("Updating blocking, setting blocking to : %s", res["block"]) connection_manager.block = bool(res["block"]) - connection_manager.conf.update( - endpoints=res.get("endpoints", []), - last_updated_at=res.get("configUpdatedAt", get_unixtime_ms()), - blocked_uids=res.get("blockedUserIds", []), - bypassed_ips=res.get("allowedIPAddresses", []), - received_any_stats=res.get("receivedAnyStats", True), + connection_manager.conf.set_endpoints(res.get("endpoints", [])) + connection_manager.conf.set_last_updated_at( + res.get("configUpdatedAt", get_unixtime_ms()) ) + connection_manager.conf.set_blocked_user_ids(res.get("blockedUserIds", [])) + connection_manager.conf.set_bypassed_ips(res.get("allowedIPAddresses", [])) + if res.get("receivedAnyStats", True): + connection_manager.conf.enable_received_stats() diff --git a/aikido_zen/background_process/service_config.py b/aikido_zen/background_process/service_config.py index c4eaa2320..b635d24bc 100644 --- a/aikido_zen/background_process/service_config.py +++ b/aikido_zen/background_process/service_config.py @@ -6,40 +6,16 @@ from aikido_zen.helpers.match_endpoints import match_endpoints -# noinspection PyAttributeOutsideInit class ServiceConfig: - """Class holding the config of the connection_manager""" - def __init__( - self, - endpoints, - last_updated_at: int, - blocked_uids, - bypassed_ips, - received_any_stats: bool, - ): - # Init the class using update function : - self.update( - endpoints, last_updated_at, blocked_uids, bypassed_ips, received_any_stats - ) - - def update( - self, - endpoints, - last_updated_at: int, - blocked_uids, - bypassed_ips, - received_any_stats: bool, - ): - self.last_updated_at = last_updated_at - self.received_any_stats = bool(received_any_stats) - self.blocked_uids = set(blocked_uids) - self.set_endpoints(endpoints) - self.set_bypassed_ips(bypassed_ips) + def __init__(self): + self.endpoints = [] + self.bypassed_ips = IPMatcher() + self.blocked_uids = set() + self.last_updated_at = -1 + self.received_any_stats = False def set_endpoints(self, endpoints): - """Sets non-graphql endpoints""" - self.endpoints = [ endpoint for endpoint in endpoints if not endpoint.get("graphql") ] @@ -66,7 +42,7 @@ def get_endpoints(self, route_metadata): return match_endpoints(route_metadata, self.endpoints) def set_bypassed_ips(self, bypassed_ips): - """Creates an IPMatcher from the given bypassed ip set""" + """Creates a new IPMatcher from the given bypassed ip set""" self.bypassed_ips = IPMatcher() for ip in bypassed_ips: self.bypassed_ips.add(ip) @@ -74,3 +50,12 @@ def set_bypassed_ips(self, bypassed_ips): def is_bypassed_ip(self, ip): """Checks if the IP is on the bypass list""" return self.bypassed_ips.has(ip) + + def set_blocked_user_ids(self, blocked_user_ids): + self.blocked_uids = set(blocked_user_ids) + + def enable_received_any_stats(self): + self.received_any_stats = True + + def set_last_updated_at(self, last_updated_at: int): + self.last_updated_at = last_updated_at diff --git a/aikido_zen/background_process/service_config_test.py b/aikido_zen/background_process/service_config_test.py index ce7a44214..5437c3a9e 100644 --- a/aikido_zen/background_process/service_config_test.py +++ b/aikido_zen/background_process/service_config_test.py @@ -4,6 +4,7 @@ def test_service_config_initialization(): + service_config = ServiceConfig() endpoints = [ { "graphql": False, @@ -51,26 +52,33 @@ def test_service_config_initialization(): "force_protection_off": False, }, ] - last_updated_at = "2023-10-01" - service_config = ServiceConfig( - endpoints, - last_updated_at, - ["0", "0", "1", "5"], - ["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"], - True, - ) - # Check that non-GraphQL endpoints are correctly filtered - assert len(service_config.endpoints) == 3 + assert len(service_config.endpoints) == 0 + service_config.set_endpoints(endpoints) + assert ( + len(service_config.endpoints) == 3 + ) # Check that non-GraphQL endpoints are correctly filtered assert service_config.endpoints[0]["route"] == "/v1" assert service_config.endpoints[1]["route"] == "/v3" assert service_config.endpoints[2]["route"] == "/admin" - assert service_config.last_updated_at == last_updated_at + + service_config.set_last_updated_at(37982562953) + assert service_config.last_updated_at == 37982562953 + + assert isinstance(service_config.bypassed_ips, IPMatcher) + service_config.set_bypassed_ips(["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"]) assert isinstance(service_config.bypassed_ips, IPMatcher) assert service_config.bypassed_ips.has("127.0.0.1") assert service_config.bypassed_ips.has("123.1.2.2") assert not service_config.bypassed_ips.has("1.1.1.1") - assert service_config.blocked_uids == set(["1", "0", "5"]) + + assert len(service_config.blocked_uids) == 0 + service_config.set_blocked_user_ids({"0", "0", "1", "5"}) + assert service_config.blocked_uids == {"1", "0", "5"} + + assert not service_config.received_any_stats + service_config.enable_received_any_stats() + assert service_config.received_any_stats == True v1_endpoint = service_config.get_endpoints( { @@ -96,41 +104,9 @@ def test_service_config_initialization(): assert not admin_endpoint["allowedIPAddresses"].has("192.168.0.1") -# Sample data for testing -sample_endpoints = [ - {"url": "http://example.com/api/v1", "graphql": False, "context": "user"}, - {"url": "http://example.com/api/v2", "graphql": True, "context": "admin"}, - {"url": "http://example.com/api/v3", "graphql": False, "context": "guest"}, -] - - -@pytest.fixture -def service_config(): - return ServiceConfig( - endpoints=sample_endpoints, - last_updated_at="2023-10-01T00:00:00Z", - blocked_uids=["user1", "user2"], - bypassed_ips=["192.168.1.1", "10.0.0.1"], - received_any_stats=True, - ) - - -def test_initialization(service_config): - assert len(service_config.endpoints) == 2 # Only non-graphql endpoints - assert service_config.last_updated_at == "2023-10-01T00:00:00Z" - assert isinstance(service_config.bypassed_ips, IPMatcher) - assert service_config.blocked_uids == {"user1", "user2"} - - def test_ip_blocking(): - config = ServiceConfig( - endpoints=sample_endpoints, - last_updated_at="2023-10-01T00:00:00Z", - blocked_uids=["user1", "user2"], - bypassed_ips=["192.168.1.1", "10.0.0.0/16", "::1/128"], - received_any_stats=True, - ) - + config = ServiceConfig() + config.set_bypassed_ips(["192.168.1.1", "10.0.0.0/16", "::1/128"]) assert config.is_bypassed_ip("192.168.1.1") assert config.is_bypassed_ip("10.0.0.1") assert config.is_bypassed_ip("10.0.1.2") @@ -142,38 +118,39 @@ def test_ip_blocking(): def test_service_config_with_empty_allowlist(): - endpoints = [ - { - "graphql": False, - "method": "GET", - "route": "/admin", - "rate_limiting": { - "enabled": False, - "max_requests": 10, - "window_size_in_ms": 1000, - }, - "allowedIPAddresses": [], - "force_protection_off": False, - }, - ] - last_updated_at = "2023-10-01" - service_config = ServiceConfig( - endpoints, - last_updated_at, - ["0", "0", "1", "5"], - ["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"], - True, - ) + service_config = ServiceConfig() # Check that non-GraphQL endpoints are correctly filtered + service_config.set_endpoints( + [ + { + "graphql": False, + "method": "GET", + "route": "/admin", + "rate_limiting": { + "enabled": False, + "max_requests": 10, + "window_size_in_ms": 1000, + }, + "allowedIPAddresses": [], + "force_protection_off": False, + }, + ] + ) assert len(service_config.endpoints) == 1 assert service_config.endpoints[0]["route"] == "/admin" - assert service_config.last_updated_at == last_updated_at + + service_config.set_last_updated_at(29839537) + assert service_config.last_updated_at == 29839537 + + service_config.set_blocked_user_ids({"0", "0", "1", "5"}) + assert service_config.blocked_uids == {"1", "0", "5"} + + service_config.set_bypassed_ips(["127.0.0.1", "123.1.2.0/24", "132.1.0.0/16"]) assert isinstance(service_config.bypassed_ips, IPMatcher) - assert service_config.bypassed_ips.has("127.0.0.1") - assert service_config.bypassed_ips.has("123.1.2.2") - assert not service_config.bypassed_ips.has("1.1.1.1") - assert service_config.blocked_uids == set(["1", "0", "5"]) + assert service_config.is_bypassed_ip("127.0.0.1") + assert service_config.is_bypassed_ip("123.1.2.2") + assert not service_config.is_bypassed_ip("1.1.1.1") admin_endpoint = service_config.get_endpoints( { diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index bb8eeb3b3..3bbe62735 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -50,7 +50,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): self.parsed_userinput = {} self.xml = {} self.outgoing_req_redirects = [] - self.set_body(body) + self.body = Context.parse_body_object(body) self.headers: Headers = Headers() self.cookies = {} self.query = {} @@ -107,26 +107,28 @@ def set_cookies(self, cookies): self.cookies = cookies def set_body(self, body): - try: - self.set_body_internal(body) - except Exception as e: - logger.debug("Exception occurred whilst setting body: %s", e) + self.body = Context.parse_body_object(body) - def set_body_internal(self, body): + @staticmethod + def parse_body_object(body): """Sets the body and checks if it's possibly JSON""" - self.body = body - if isinstance(self.body, (str, bytes)) and len(body) == 0: - # Make sure that empty bodies like b"" don't get sent. - self.body = None - if isinstance(self.body, bytes): - self.body = self.body.decode("utf-8") # Decode byte input to string. - if not isinstance(self.body, str): - return - if self.body.strip()[0] in ["{", "[", '"']: - # Might be JSON, but might not have been parsed correctly by server because of wrong headers - parsed_body = json.loads(self.body) - if parsed_body: - self.body = parsed_body + try: + if isinstance(body, (str, bytes)) and len(body) == 0: + # Make sure that empty bodies like b"" don't get sent. + return None + if isinstance(body, bytes): + body = body.decode("utf-8") # Decode byte input to string. + if not isinstance(body, str): + return body + if body.strip()[0] in ["{", "[", '"']: + # Might be JSON, but might not have been parsed correctly by server because of wrong headers + parsed_body = json.loads(body) + if parsed_body: + return parsed_body + return body + except Exception as e: + logger.debug("Exception occurred whilst parsing body: %s", e) + return body def get_route_metadata(self): """Returns a route_metadata object""" diff --git a/aikido_zen/ratelimiting/init_test.py b/aikido_zen/ratelimiting/init_test.py index 966d8c7ac..017db8124 100644 --- a/aikido_zen/ratelimiting/init_test.py +++ b/aikido_zen/ratelimiting/init_test.py @@ -17,13 +17,10 @@ def user(): def create_connection_manager(endpoints=[], bypassed_ips=[]): cm = MagicMock() - cm.conf = ServiceConfig( - endpoints=endpoints, - last_updated_at=1, - blocked_uids=[], - bypassed_ips=bypassed_ips, - received_any_stats=True, - ) + cm.conf = ServiceConfig() + cm.conf.set_endpoints(endpoints) + cm.conf.enable_received_any_stats() + cm.conf.set_bypassed_ips(bypassed_ips) cm.rate_limiter = RateLimiter( max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes ) diff --git a/aikido_zen/sources/functions/request_handler_test.py b/aikido_zen/sources/functions/request_handler_test.py index db76717f3..fac92b651 100644 --- a/aikido_zen/sources/functions/request_handler_test.py +++ b/aikido_zen/sources/functions/request_handler_test.py @@ -146,19 +146,16 @@ def set_context(remote_address, user_agent=""): def create_service_config(): - config = ServiceConfig( - endpoints=[ + config = ServiceConfig() + config.set_endpoints( + [ { "method": "POST", "route": "/posts/:number", "graphql": False, "allowedIPAddresses": ["1.1.1.1", "2.2.2.2", "3.3.3.3"], } - ], - last_updated_at=None, - blocked_uids=set(), - bypassed_ips=[], - received_any_stats=False, + ] ) get_cache().config = config return config diff --git a/aikido_zen/thread/thread_cache.py b/aikido_zen/thread/thread_cache.py index 0072bac3c..10f3f12e1 100644 --- a/aikido_zen/thread/thread_cache.py +++ b/aikido_zen/thread/thread_cache.py @@ -17,11 +17,13 @@ class ThreadCache: """ def __init__(self): + self.config = ServiceConfig() + self.middleware_installed = False + self.routes = Routes(max_size=1000) self.hostnames = Hostnames(200) self.users = Users(1000) self.stats = Statistics() self.ai_stats = AIStatistics() - self.reset() # Initialize values def is_bypassed_ip(self, ip): """Checks the given IP against the list of bypassed ips""" @@ -35,16 +37,9 @@ def get_endpoints(self): return self.config.endpoints def reset(self): - """Empties out all values of the cache""" - self.routes = Routes(max_size=1000) - self.config = ServiceConfig( - endpoints=[], - blocked_uids=set(), - bypassed_ips=[], - last_updated_at=-1, - received_any_stats=False, - ) self.middleware_installed = False + self.config = ServiceConfig() + self.routes.clear() self.hostnames.clear() self.users.clear() self.stats.clear() diff --git a/aikido_zen/thread/thread_cache_test.py b/aikido_zen/thread/thread_cache_test.py index 83b484643..325d34d20 100644 --- a/aikido_zen/thread/thread_cache_test.py +++ b/aikido_zen/thread/thread_cache_test.py @@ -150,26 +150,27 @@ def increment_in_thread(): def test_parses_routes_correctly(mock_get_comms, thread_cache: ThreadCache): """Test renewing the cache multiple times if TTL has expired.""" mock_get_comms.return_value = MagicMock() + service_config = ServiceConfig() + service_config.set_endpoints( + [ + { + "graphql": False, + "method": "POST", + "route": "/v2", + "rate_limiting": { + "enabled": False, + }, + "force_protection_off": False, + } + ] + ) + service_config.set_bypassed_ips(["192.168.1.1"]) + service_config.set_blocked_user_ids({"user123"}) + service_config.enable_received_any_stats() mock_get_comms.return_value.send_data_to_bg_process.return_value = { "success": True, "data": { - "config": ServiceConfig( - endpoints=[ - { - "graphql": False, - "method": "POST", - "route": "/v2", - "rate_limiting": { - "enabled": False, - }, - "force_protection_off": False, - } - ], - bypassed_ips=["192.168.1.1"], - blocked_uids={"user123"}, - last_updated_at=-1, - received_any_stats=True, - ), + "config": service_config, "routes": { "POST:/body": { "method": "POST",