Skip to content

Commit 24b1874

Browse files
committed
Handle ids in classes themselves
1 parent d05f70a commit 24b1874

File tree

7 files changed

+74
-144
lines changed

7 files changed

+74
-144
lines changed

maltoolbox/attackgraph/attacker.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,25 @@
1515

1616
class Attacker:
1717

18+
_max_id = -1
19+
1820
def __init__(
1921
self,
2022
name: str,
21-
entry_points: set[AttackGraphNode],
22-
reached_attack_steps: set[AttackGraphNode],
23-
attacker_id: Optional[int] = None
23+
entry_points: set[AttackGraphNode] = None,
24+
reached_attack_steps: set[AttackGraphNode] = None,
25+
attacker_id: Optional[int] = None,
2426
):
2527
self.name = name
26-
self.entry_points = entry_points
27-
self.reached_attack_steps = reached_attack_steps
28-
self.id = attacker_id
28+
self.entry_points = entry_points or set()
29+
self.reached_attack_steps = reached_attack_steps or set()
30+
31+
Attacker._max_id = max(Attacker._max_id + 1, attacker_id or 0)
32+
self.id = Attacker._max_id
33+
34+
@staticmethod
35+
def reset_ids(id=None):
36+
Attacker._max_id = id if id is not None else -1
2937

3038
def to_dict(self) -> dict:
3139
attacker_dict: dict = {
@@ -60,12 +68,15 @@ def __deepcopy__(self, memo) -> Attacker:
6068
if id(self) in memo:
6169
return memo[id(self)]
6270

71+
old_max_id = Attacker._max_id
72+
Attacker.reset_ids()
6373
copied_attacker = Attacker(
6474
name = self.name,
6575
attacker_id = self.id,
6676
entry_points = set(),
6777
reached_attack_steps = set()
6878
)
79+
Attacker.reset_ids(old_max_id)
6980

7081
# Remember that self was already copied
7182
memo[id(self)] = copied_attacker

maltoolbox/attackgraph/attackgraph.py

Lines changed: 20 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,6 @@ def __init__(self, lang_graph, model: Optional[Model] = None):
9090

9191
self.model = model
9292
self.lang_graph = lang_graph
93-
self.next_node_id = 0
94-
self.next_attacker_id = 0
9593
if self.model is not None:
9694
self._generate_graph()
9795

@@ -154,10 +152,6 @@ def __deepcopy__(self, memo):
154152
copied_attackgraph._full_name_to_node = \
155153
copy.deepcopy(self._full_name_to_node, memo)
156154

157-
# Copy counters
158-
copied_attackgraph.next_node_id = self.next_node_id
159-
copied_attackgraph.next_attacker_id = self.next_attacker_id
160-
161155
return copied_attackgraph
162156

163157
def save_to_file(self, filename: str) -> None:
@@ -201,15 +195,17 @@ def _from_dict(
201195
node_dict['lang_graph_attack_step'])
202196
lg_attack_step = lang_graph.assets[lg_asset_name].\
203197
attack_steps[lg_attack_step_name]
204-
ag_node = attack_graph.add_node(
198+
ag_node = AttackGraphNode(
205199
lg_attack_step = lg_attack_step,
206-
node_id = node_dict['id'],
207200
model_asset = node_asset,
201+
node_id = node_dict['id'],
208202
defense_status = node_dict.get('defense_status', None),
209203
existence_status = node_dict.get('existence_status', None)
210204
)
211205
ag_node.tags = set(node_dict.get('tags', []))
212206
ag_node.extras = node_dict.get('extras', {})
207+
attack_graph.nodes[ag_node.id] = ag_node
208+
attack_graph._full_name_to_node[ag_node.full_name] = ag_node
213209

214210
if node_asset:
215211
# Add AttackGraphNode to attack_step_nodes of asset
@@ -249,20 +245,14 @@ def _from_dict(
249245
_ag_node.parents.add(parent)
250246

251247
for attacker in serialized_attackers.values():
252-
ag_attacker = Attacker(
253-
name = attacker['name'],
254-
entry_points = set(),
255-
reached_attack_steps = set()
256-
)
257248
attack_graph.add_attacker(
258-
attacker = ag_attacker,
259-
attacker_id = int(attacker['id']),
249+
attacker=Attacker(name=attacker['name']),
260250
entry_points = [
261251
int(node_id) # Convert to int since they can be strings
262252
for node_id in attacker['entry_points'].keys()
263253
],
264-
reached_attack_steps = [
265-
int(node_id) # Convert to int since they can be strings
254+
reached_attack_steps=[
255+
int(node_id) # Convert to int since they can be strings
266256
for node_id in attacker['reached_attack_steps'].keys()
267257
]
268258
)
@@ -323,18 +313,15 @@ def attach_attackers(self) -> None:
323313
'Attach attackers from "%s" model to the graph.', self.model.name
324314
)
325315

316+
Attacker.reset_ids()
326317
for attacker_info in self.model.attackers:
327318

328319
if not attacker_info.name:
329320
msg = "Can not attach attacker without name"
330321
logger.error(msg)
331322
raise AttackGraphException(msg)
332323

333-
attacker = Attacker(
334-
name = attacker_info.name,
335-
entry_points = set(),
336-
reached_attack_steps = set()
337-
)
324+
attacker = Attacker(attacker_info.name)
338325
self.add_attacker(attacker)
339326

340327
for (asset, attack_steps) in attacker_info.entry_points:
@@ -520,6 +507,7 @@ def _generate_graph(self) -> None:
520507
Generate the attack graph based on the original model instance and the
521508
MAL language specification provided at initialization.
522509
"""
510+
AttackGraphNode.reset_ids()
523511

524512
if not self.model:
525513
msg = "Can not generate AttackGraph without model"
@@ -580,12 +568,14 @@ def _generate_graph(self) -> None:
580568
case _:
581569
pass
582570

583-
ag_node = self.add_node(
571+
ag_node = AttackGraphNode(
584572
lg_attack_step = attack_step,
585573
model_asset = asset,
586574
defense_status = defense_status,
587575
existence_status = existence_status
588576
)
577+
self.nodes[ag_node.id] = ag_node
578+
self._full_name_to_node[ag_node.full_name] = ag_node
589579
attack_step_nodes.append(ag_node)
590580

591581
asset.attack_step_nodes = attack_step_nodes
@@ -669,64 +659,6 @@ def regenerate_graph(self) -> None:
669659
self.attackers = {}
670660
self._generate_graph()
671661

672-
def add_node(
673-
self,
674-
lg_attack_step: LanguageGraphAttackStep,
675-
node_id: Optional[int] = None,
676-
model_asset: Optional[ModelAsset] = None,
677-
defense_status: Optional[float] = None,
678-
existence_status: Optional[bool] = None
679-
) -> AttackGraphNode:
680-
"""Create and add a node to the graph
681-
Arguments:
682-
lg_attack_step - the language graph attack step that corresponds
683-
to the attack graph node to create
684-
node_id - id to assign to the newly created node, usually
685-
provided only when loading an existing attack
686-
graph from a file. If not provided the id will
687-
be set to the next highest id available.
688-
model_asset - the model asset that corresponds to the attack
689-
step node. While optional it is highly
690-
recommended that this be provided. It should
691-
only be ommitted if the model which was used to
692-
generate the attack graph is not available when
693-
loading an attack graph from a file.
694-
defese_status - the defense status of the node. Only, relevant
695-
for defense type nodes. A value between 0.0 and
696-
1.0 is expected.
697-
existence_status - the existence status of the node. Only, relevant
698-
for exist and notExist type nodes.
699-
700-
Return:
701-
The newly created attack step node.
702-
"""
703-
node_id = node_id if node_id is not None else self.next_node_id
704-
if node_id in self.nodes:
705-
raise ValueError(f'Node index {node_id} already in use.')
706-
self.next_node_id = max(node_id + 1, self.next_node_id)
707-
708-
if logger.isEnabledFor(logging.DEBUG):
709-
# Avoid running json.dumps when not in debug
710-
logger.debug('Create and add to attackgraph node of type "%s" '
711-
'with id:%d.\n' % (
712-
lg_attack_step.full_name,
713-
node_id
714-
))
715-
716-
717-
node = AttackGraphNode(
718-
node_id = node_id,
719-
lg_attack_step = lg_attack_step,
720-
model_asset = model_asset,
721-
defense_status = defense_status,
722-
existence_status = existence_status
723-
)
724-
725-
self.nodes[node_id] = node
726-
self._full_name_to_node[node.full_name] = node
727-
728-
return node
729-
730662
def remove_node(self, node: AttackGraphNode) -> None:
731663
"""Remove node from attack graph
732664
Arguments:
@@ -748,9 +680,8 @@ def remove_node(self, node: AttackGraphNode) -> None:
748680
def add_attacker(
749681
self,
750682
attacker: Attacker,
751-
attacker_id: Optional[int] = None,
752-
entry_points: list[int] = [],
753-
reached_attack_steps: list[int] = []
683+
entry_points: list[int] = None,
684+
reached_attack_steps: list[int] = None
754685
):
755686
"""Add an attacker to the graph
756687
Arguments:
@@ -765,33 +696,19 @@ def add_attacker(
765696
"""
766697

767698
if logger.isEnabledFor(logging.DEBUG):
768-
# Avoid running json.dumps when not in debug
769-
if attacker_id is not None:
770-
logger.debug('Add attacker "%s" with id:%d.',
771-
attacker.name,
772-
attacker_id
773-
)
774-
else:
775-
logger.debug('Add attacker "%s" without id.',
776-
attacker.name
777-
)
778-
779-
attacker.id = attacker_id or self.next_attacker_id
780-
if attacker.id in self.attackers:
781-
raise ValueError(f'Attacker index {attacker_id} already in use.')
699+
logger.debug('Add attacker "%s" with id:%d.', attacker.name, attacker.id)
782700

783-
self.next_attacker_id = max(attacker.id + 1, self.next_attacker_id)
784-
for node_id in reached_attack_steps:
785-
node = self.nodes[node_id]
701+
for node_id in reached_attack_steps or []:
702+
node = self.get_node_by_id(node_id)
786703
if node:
787704
attacker.compromise(node)
788705
else:
789706
msg = ("Could not find node with id %d"
790707
"in reached attack steps.")
791708
logger.error(msg, node_id)
792709
raise AttackGraphException(msg % node_id)
793-
for node_id in entry_points:
794-
node = self.nodes[node_id]
710+
for node_id in entry_points or []:
711+
node = self.get_node_by_id(int(node_id))
795712
if node:
796713
attacker.entry_points.add(node)
797714
else:

maltoolbox/attackgraph/node.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
class AttackGraphNode:
1717
"""Node part of AttackGraph"""
1818

19+
_max_id = -1
20+
1921
def __init__(
2022
self,
21-
node_id: int,
2223
lg_attack_step: LanguageGraphAttackStep,
2324
model_asset: Optional[ModelAsset] = None,
2425
defense_status: Optional[float] = None,
25-
existence_status: Optional[bool] = None
26+
existence_status: Optional[bool] = None,
27+
node_id: int = None,
2628
):
2729
self.lg_attack_step = lg_attack_step
2830
self.name = lg_attack_step.name
@@ -31,7 +33,9 @@ def __init__(
3133
self.tags = lg_attack_step.tags
3234
self.detectors = lg_attack_step.detectors
3335

34-
self.id = node_id
36+
AttackGraphNode._max_id = max(self._max_id + 1, node_id or 0)
37+
self.id = AttackGraphNode._max_id
38+
3539
self.model_asset = model_asset
3640
self.defense_status = defense_status
3741
self.existence_status = existence_status
@@ -43,6 +47,10 @@ def __init__(
4347
self.compromised_by: set[Attacker] = set()
4448
self.extras: dict = {}
4549

50+
@staticmethod
51+
def reset_ids(id = None):
52+
AttackGraphNode._max_id = id if id is not None else -1
53+
4654
def to_dict(self) -> dict:
4755
"""Convert node to dictionary"""
4856
node_dict: dict = {
@@ -98,11 +106,14 @@ def __deepcopy__(self, memo) -> AttackGraphNode:
98106
if id(self) in memo:
99107
return memo[id(self)]
100108

109+
old_max_id = AttackGraphNode._max_id
110+
AttackGraphNode.reset_ids()
101111
copied_node = AttackGraphNode(
102112
node_id = self.id,
103113
model_asset = self.model_asset,
104114
lg_attack_step = self.lg_attack_step
105115
)
116+
AttackGraphNode.reset_ids(old_max_id)
106117

107118
copied_node.tags = copy.deepcopy(self.tags, memo)
108119
copied_node.extras = copy.deepcopy(self.extras, memo)

tests/attackgraph/test_analyzer.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for analyzers"""
22

3-
from maltoolbox.attackgraph.attackgraph import AttackGraph
3+
from maltoolbox.attackgraph.attackgraph import AttackGraph, AttackGraphNode
44
from maltoolbox.attackgraph.analyzers.apriori import (
55
propagate_viability_from_unviable_node,
66
prune_unviable_and_unnecessary_nodes,
@@ -63,25 +63,24 @@ def test_analyzers_apriori_propagate_viability_from_unviable_node(dummy_lang_gra
6363
attack_steps['DummyOrAttackStep']
6464
dummy_defense_attack_step = dummy_lang_graph.assets['DummyAsset'].\
6565
attack_steps['DummyDefenseAttackStep']
66-
attack_graph = AttackGraph(dummy_lang_graph)
6766

6867
# Create a graph of nodes according to above diagram
69-
node1 = attack_graph.add_node(
68+
node1 = AttackGraphNode(
7069
lg_attack_step = dummy_defense_attack_step
7170
)
72-
node2 = attack_graph.add_node(
71+
node2 = AttackGraphNode(
7372
lg_attack_step = dummy_or_attack_step
7473
)
75-
node3 = attack_graph.add_node(
74+
node3 = AttackGraphNode(
7675
lg_attack_step = dummy_or_attack_step
7776
)
78-
node4 = attack_graph.add_node(
77+
node4 = AttackGraphNode(
7978
lg_attack_step = dummy_or_attack_step
8079
)
81-
node5 = attack_graph.add_node(
80+
node5 = AttackGraphNode(
8281
lg_attack_step = dummy_or_attack_step
8382
)
84-
node6 = attack_graph.add_node(
83+
node6 = AttackGraphNode(
8584
lg_attack_step = dummy_or_attack_step
8685
)
8786

0 commit comments

Comments
 (0)