From 24b187450cab00437e11303ff1b47c2f0450185f Mon Sep 17 00:00:00 2001 From: Nikolaos Kakouros Date: Thu, 13 Feb 2025 12:46:31 +0200 Subject: [PATCH] Handle ids in classes themselves --- maltoolbox/attackgraph/attacker.py | 23 +++-- maltoolbox/attackgraph/attackgraph.py | 123 +++++--------------------- maltoolbox/attackgraph/node.py | 17 +++- tests/attackgraph/test_analyzer.py | 15 ++-- tests/attackgraph/test_attacker.py | 13 +-- tests/attackgraph/test_node.py | 15 ++-- tests/attackgraph/test_query.py | 12 +-- 7 files changed, 74 insertions(+), 144 deletions(-) diff --git a/maltoolbox/attackgraph/attacker.py b/maltoolbox/attackgraph/attacker.py index 645ca95e..af17f219 100644 --- a/maltoolbox/attackgraph/attacker.py +++ b/maltoolbox/attackgraph/attacker.py @@ -15,17 +15,25 @@ class Attacker: + _max_id = -1 + def __init__( self, name: str, - entry_points: set[AttackGraphNode], - reached_attack_steps: set[AttackGraphNode], - attacker_id: Optional[int] = None + entry_points: set[AttackGraphNode] = None, + reached_attack_steps: set[AttackGraphNode] = None, + attacker_id: Optional[int] = None, ): self.name = name - self.entry_points = entry_points - self.reached_attack_steps = reached_attack_steps - self.id = attacker_id + self.entry_points = entry_points or set() + self.reached_attack_steps = reached_attack_steps or set() + + Attacker._max_id = max(Attacker._max_id + 1, attacker_id or 0) + self.id = Attacker._max_id + + @staticmethod + def reset_ids(id=None): + Attacker._max_id = id if id is not None else -1 def to_dict(self) -> dict: attacker_dict: dict = { @@ -60,12 +68,15 @@ def __deepcopy__(self, memo) -> Attacker: if id(self) in memo: return memo[id(self)] + old_max_id = Attacker._max_id + Attacker.reset_ids() copied_attacker = Attacker( name = self.name, attacker_id = self.id, entry_points = set(), reached_attack_steps = set() ) + Attacker.reset_ids(old_max_id) # Remember that self was already copied memo[id(self)] = copied_attacker diff --git a/maltoolbox/attackgraph/attackgraph.py b/maltoolbox/attackgraph/attackgraph.py index 6295142a..b40641e0 100644 --- a/maltoolbox/attackgraph/attackgraph.py +++ b/maltoolbox/attackgraph/attackgraph.py @@ -90,8 +90,6 @@ def __init__(self, lang_graph, model: Optional[Model] = None): self.model = model self.lang_graph = lang_graph - self.next_node_id = 0 - self.next_attacker_id = 0 if self.model is not None: self._generate_graph() @@ -154,10 +152,6 @@ def __deepcopy__(self, memo): copied_attackgraph._full_name_to_node = \ copy.deepcopy(self._full_name_to_node, memo) - # Copy counters - copied_attackgraph.next_node_id = self.next_node_id - copied_attackgraph.next_attacker_id = self.next_attacker_id - return copied_attackgraph def save_to_file(self, filename: str) -> None: @@ -201,15 +195,17 @@ def _from_dict( node_dict['lang_graph_attack_step']) lg_attack_step = lang_graph.assets[lg_asset_name].\ attack_steps[lg_attack_step_name] - ag_node = attack_graph.add_node( + ag_node = AttackGraphNode( lg_attack_step = lg_attack_step, - node_id = node_dict['id'], model_asset = node_asset, + node_id = node_dict['id'], defense_status = node_dict.get('defense_status', None), existence_status = node_dict.get('existence_status', None) ) ag_node.tags = set(node_dict.get('tags', [])) ag_node.extras = node_dict.get('extras', {}) + attack_graph.nodes[ag_node.id] = ag_node + attack_graph._full_name_to_node[ag_node.full_name] = ag_node if node_asset: # Add AttackGraphNode to attack_step_nodes of asset @@ -249,20 +245,14 @@ def _from_dict( _ag_node.parents.add(parent) for attacker in serialized_attackers.values(): - ag_attacker = Attacker( - name = attacker['name'], - entry_points = set(), - reached_attack_steps = set() - ) attack_graph.add_attacker( - attacker = ag_attacker, - attacker_id = int(attacker['id']), + attacker=Attacker(name=attacker['name']), entry_points = [ int(node_id) # Convert to int since they can be strings for node_id in attacker['entry_points'].keys() ], - reached_attack_steps = [ - int(node_id) # Convert to int since they can be strings + reached_attack_steps=[ + int(node_id) # Convert to int since they can be strings for node_id in attacker['reached_attack_steps'].keys() ] ) @@ -323,6 +313,7 @@ def attach_attackers(self) -> None: 'Attach attackers from "%s" model to the graph.', self.model.name ) + Attacker.reset_ids() for attacker_info in self.model.attackers: if not attacker_info.name: @@ -330,11 +321,7 @@ def attach_attackers(self) -> None: logger.error(msg) raise AttackGraphException(msg) - attacker = Attacker( - name = attacker_info.name, - entry_points = set(), - reached_attack_steps = set() - ) + attacker = Attacker(attacker_info.name) self.add_attacker(attacker) for (asset, attack_steps) in attacker_info.entry_points: @@ -520,6 +507,7 @@ def _generate_graph(self) -> None: Generate the attack graph based on the original model instance and the MAL language specification provided at initialization. """ + AttackGraphNode.reset_ids() if not self.model: msg = "Can not generate AttackGraph without model" @@ -580,12 +568,14 @@ def _generate_graph(self) -> None: case _: pass - ag_node = self.add_node( + ag_node = AttackGraphNode( lg_attack_step = attack_step, model_asset = asset, defense_status = defense_status, existence_status = existence_status ) + self.nodes[ag_node.id] = ag_node + self._full_name_to_node[ag_node.full_name] = ag_node attack_step_nodes.append(ag_node) asset.attack_step_nodes = attack_step_nodes @@ -669,64 +659,6 @@ def regenerate_graph(self) -> None: self.attackers = {} self._generate_graph() - def add_node( - self, - lg_attack_step: LanguageGraphAttackStep, - node_id: Optional[int] = None, - model_asset: Optional[ModelAsset] = None, - defense_status: Optional[float] = None, - existence_status: Optional[bool] = None - ) -> AttackGraphNode: - """Create and add a node to the graph - Arguments: - lg_attack_step - the language graph attack step that corresponds - to the attack graph node to create - node_id - id to assign to the newly created node, usually - provided only when loading an existing attack - graph from a file. If not provided the id will - be set to the next highest id available. - model_asset - the model asset that corresponds to the attack - step node. While optional it is highly - recommended that this be provided. It should - only be ommitted if the model which was used to - generate the attack graph is not available when - loading an attack graph from a file. - defese_status - the defense status of the node. Only, relevant - for defense type nodes. A value between 0.0 and - 1.0 is expected. - existence_status - the existence status of the node. Only, relevant - for exist and notExist type nodes. - - Return: - The newly created attack step node. - """ - node_id = node_id if node_id is not None else self.next_node_id - if node_id in self.nodes: - raise ValueError(f'Node index {node_id} already in use.') - self.next_node_id = max(node_id + 1, self.next_node_id) - - if logger.isEnabledFor(logging.DEBUG): - # Avoid running json.dumps when not in debug - logger.debug('Create and add to attackgraph node of type "%s" ' - 'with id:%d.\n' % ( - lg_attack_step.full_name, - node_id - )) - - - node = AttackGraphNode( - node_id = node_id, - lg_attack_step = lg_attack_step, - model_asset = model_asset, - defense_status = defense_status, - existence_status = existence_status - ) - - self.nodes[node_id] = node - self._full_name_to_node[node.full_name] = node - - return node - def remove_node(self, node: AttackGraphNode) -> None: """Remove node from attack graph Arguments: @@ -748,9 +680,8 @@ def remove_node(self, node: AttackGraphNode) -> None: def add_attacker( self, attacker: Attacker, - attacker_id: Optional[int] = None, - entry_points: list[int] = [], - reached_attack_steps: list[int] = [] + entry_points: list[int] = None, + reached_attack_steps: list[int] = None ): """Add an attacker to the graph Arguments: @@ -765,24 +696,10 @@ def add_attacker( """ if logger.isEnabledFor(logging.DEBUG): - # Avoid running json.dumps when not in debug - if attacker_id is not None: - logger.debug('Add attacker "%s" with id:%d.', - attacker.name, - attacker_id - ) - else: - logger.debug('Add attacker "%s" without id.', - attacker.name - ) - - attacker.id = attacker_id or self.next_attacker_id - if attacker.id in self.attackers: - raise ValueError(f'Attacker index {attacker_id} already in use.') + logger.debug('Add attacker "%s" with id:%d.', attacker.name, attacker.id) - self.next_attacker_id = max(attacker.id + 1, self.next_attacker_id) - for node_id in reached_attack_steps: - node = self.nodes[node_id] + for node_id in reached_attack_steps or []: + node = self.get_node_by_id(node_id) if node: attacker.compromise(node) else: @@ -790,8 +707,8 @@ def add_attacker( "in reached attack steps.") logger.error(msg, node_id) raise AttackGraphException(msg % node_id) - for node_id in entry_points: - node = self.nodes[node_id] + for node_id in entry_points or []: + node = self.get_node_by_id(int(node_id)) if node: attacker.entry_points.add(node) else: diff --git a/maltoolbox/attackgraph/node.py b/maltoolbox/attackgraph/node.py index 1d9466f1..449ebb63 100644 --- a/maltoolbox/attackgraph/node.py +++ b/maltoolbox/attackgraph/node.py @@ -16,13 +16,15 @@ class AttackGraphNode: """Node part of AttackGraph""" + _max_id = -1 + def __init__( self, - node_id: int, lg_attack_step: LanguageGraphAttackStep, model_asset: Optional[ModelAsset] = None, defense_status: Optional[float] = None, - existence_status: Optional[bool] = None + existence_status: Optional[bool] = None, + node_id: int = None, ): self.lg_attack_step = lg_attack_step self.name = lg_attack_step.name @@ -31,7 +33,9 @@ def __init__( self.tags = lg_attack_step.tags self.detectors = lg_attack_step.detectors - self.id = node_id + AttackGraphNode._max_id = max(self._max_id + 1, node_id or 0) + self.id = AttackGraphNode._max_id + self.model_asset = model_asset self.defense_status = defense_status self.existence_status = existence_status @@ -43,6 +47,10 @@ def __init__( self.compromised_by: set[Attacker] = set() self.extras: dict = {} + @staticmethod + def reset_ids(id = None): + AttackGraphNode._max_id = id if id is not None else -1 + def to_dict(self) -> dict: """Convert node to dictionary""" node_dict: dict = { @@ -98,11 +106,14 @@ def __deepcopy__(self, memo) -> AttackGraphNode: if id(self) in memo: return memo[id(self)] + old_max_id = AttackGraphNode._max_id + AttackGraphNode.reset_ids() copied_node = AttackGraphNode( node_id = self.id, model_asset = self.model_asset, lg_attack_step = self.lg_attack_step ) + AttackGraphNode.reset_ids(old_max_id) copied_node.tags = copy.deepcopy(self.tags, memo) copied_node.extras = copy.deepcopy(self.extras, memo) diff --git a/tests/attackgraph/test_analyzer.py b/tests/attackgraph/test_analyzer.py index 5bd8d883..5ddbb051 100644 --- a/tests/attackgraph/test_analyzer.py +++ b/tests/attackgraph/test_analyzer.py @@ -1,6 +1,6 @@ """Tests for analyzers""" -from maltoolbox.attackgraph.attackgraph import AttackGraph +from maltoolbox.attackgraph.attackgraph import AttackGraph, AttackGraphNode from maltoolbox.attackgraph.analyzers.apriori import ( propagate_viability_from_unviable_node, prune_unviable_and_unnecessary_nodes, @@ -63,25 +63,24 @@ def test_analyzers_apriori_propagate_viability_from_unviable_node(dummy_lang_gra attack_steps['DummyOrAttackStep'] dummy_defense_attack_step = dummy_lang_graph.assets['DummyAsset'].\ attack_steps['DummyDefenseAttackStep'] - attack_graph = AttackGraph(dummy_lang_graph) # Create a graph of nodes according to above diagram - node1 = attack_graph.add_node( + node1 = AttackGraphNode( lg_attack_step = dummy_defense_attack_step ) - node2 = attack_graph.add_node( + node2 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) - node3 = attack_graph.add_node( + node3 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) - node4 = attack_graph.add_node( + node4 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) - node5 = attack_graph.add_node( + node5 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) - node6 = attack_graph.add_node( + node6 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) diff --git a/tests/attackgraph/test_attacker.py b/tests/attackgraph/test_attacker.py index 8547e597..2e9aec9a 100644 --- a/tests/attackgraph/test_attacker.py +++ b/tests/attackgraph/test_attacker.py @@ -10,14 +10,13 @@ def test_attacker_to_dict(dummy_lang_graph: LanguageGraph): dummy_or_attack_step = dummy_lang_graph.assets['DummyAsset'].\ attack_steps['DummyOrAttackStep'] - attack_graph = AttackGraph(dummy_lang_graph) - node1 = attack_graph.add_node( + node1 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) attacker = Attacker("Test Attacker", set(), {node1}) assert attacker.to_dict() == { - "id": None, + "id": 0, "name": "Test Attacker", "entry_points": {}, "reached_attack_steps": { @@ -30,15 +29,12 @@ def test_attacker_compromise(dummy_lang_graph: LanguageGraph): dummy_or_attack_step = dummy_lang_graph.assets['DummyAsset'].\ attack_steps['DummyOrAttackStep'] - attack_graph = AttackGraph(dummy_lang_graph) - node1 = attack_graph.add_node( + node1 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) attacker = Attacker("Test Attacker", set(), set()) assert not attacker.entry_points - attack_graph = AttackGraph(dummy_lang_graph) - attack_graph.add_attacker(attacker) attacker.compromise(node1) assert attacker.reached_attack_steps == {node1} @@ -55,9 +51,8 @@ def test_attacker_undo_compromise(dummy_lang_graph: LanguageGraph): dummy_or_attack_step = dummy_lang_graph.assets['DummyAsset'].\ attack_steps['DummyOrAttackStep'] - attack_graph = AttackGraph(dummy_lang_graph) - node1 = attack_graph.add_node( + node1 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) attacker = Attacker("attacker1", set(), set()) diff --git a/tests/attackgraph/test_node.py b/tests/attackgraph/test_node.py index 504723f0..4c8cf063 100644 --- a/tests/attackgraph/test_node.py +++ b/tests/attackgraph/test_node.py @@ -21,27 +21,26 @@ def test_attackgraphnode(dummy_lang_graph: LanguageGraph): attack_steps['DummyAndAttackStep'] dummy_defense_attack_step = dummy_lang_graph.assets['DummyAsset'].\ attack_steps['DummyDefenseAttackStep'] - attack_graph = AttackGraph(dummy_lang_graph) # Create a graph of nodes according to above diagram - node1 = attack_graph.add_node( + node1 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) - node2 = attack_graph.add_node( + node2 = AttackGraphNode( lg_attack_step = dummy_defense_attack_step ) node2.defense_status = 1.0 - node3 = attack_graph.add_node( + node3 = AttackGraphNode( lg_attack_step = dummy_defense_attack_step ) node3.defense_status = 0.0 - node4 = attack_graph.add_node( + node4 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) - node5 = attack_graph.add_node( + node5 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) - node6 = attack_graph.add_node( + node6 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) @@ -61,8 +60,6 @@ def test_attackgraphnode(dummy_lang_graph: LanguageGraph): reached_attack_steps = set() ) - attack_graph.add_attacker(attacker) - node6.compromise(attacker) assert node6.compromised_by == {attacker} assert node6.is_compromised() diff --git a/tests/attackgraph/test_query.py b/tests/attackgraph/test_query.py index 9d8105cf..f4ddbb48 100644 --- a/tests/attackgraph/test_query.py +++ b/tests/attackgraph/test_query.py @@ -13,7 +13,6 @@ def test_query_is_node_traversable_by_attacker(dummy_lang_graph: LanguageGraph): attack_steps['DummyOrAttackStep'] dummy_and_attack_step = dummy_lang_graph.assets['DummyAsset'].\ attack_steps['DummyAndAttackStep'] - attack_graph = AttackGraph(dummy_lang_graph) # An attacker with no meaningful data attacker = Attacker( @@ -21,32 +20,33 @@ def test_query_is_node_traversable_by_attacker(dummy_lang_graph: LanguageGraph): entry_points = set(), reached_attack_steps = set() ) - attack_graph.add_attacker(attacker) # Node1 should be traversable since node type is OR - node1 = attack_graph.add_node( + node1 = AttackGraphNode( lg_attack_step = dummy_or_attack_step ) traversable = is_node_traversable_by_attacker(node1, attacker) assert traversable # Node2 should be traversable since node has no parents - node2 = attack_graph.add_node( + node2 = AttackGraphNode( lg_attack_step = dummy_and_attack_step ) + traversable = is_node_traversable_by_attacker(node2, attacker) assert traversable # Node 4 should not be traversable since node has type AND # and it has two parents that are not compromised by attacker - node3 = attack_graph.add_node( + node3 = AttackGraphNode( lg_attack_step = dummy_and_attack_step ) - node4 = attack_graph.add_node( + node4 = AttackGraphNode( lg_attack_step = dummy_and_attack_step ) node4.parents = {node2, node3} node2.children = {node4} node3.children = {node4} + traversable = is_node_traversable_by_attacker(node4, attacker) assert not traversable