Skip to content

Commit b3024e4

Browse files
committed
Rename probability related methods and classes away from TTC since they can be used with any probabilities
1 parent c765e8a commit b3024e4

File tree

3 files changed

+28
-31
lines changed

3 files changed

+28
-31
lines changed

maltoolbox/language/languagegraph.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,8 @@ class LanguageGraph():
654654
"""Graph representation of a MAL language"""
655655
def __init__(self, lang: Optional[dict] = None):
656656
self.assets: dict = {}
657-
self.load_predefined_ttcs()
657+
self.predef_ttcs = load_dict_from_yaml_file(
658+
'maltoolbox/language/predefined_ttcs.yml')
658659
if lang is not None:
659660
self._lang_spec: dict = lang
660661
self.metadata = {
@@ -696,11 +697,6 @@ def from_mar_archive(cls, mar_archive: str) -> LanguageGraph:
696697
return LanguageGraph(json.loads(langspec))
697698

698699

699-
def load_predefined_ttcs(self):
700-
"""Load the predefined ttcs into a dictionary"""
701-
self.predef_ttcs = load_dict_from_yaml_file(
702-
'maltoolbox/language/predefined_ttcs.yml')
703-
704700
def replace_if_predef_ttc(self, ttc_entry: dict) -> dict:
705701
"""
706702
If the TTC provided is a predefined name replace it with the

maltoolbox/probs_utils.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,22 @@
77

88
logger = logging.getLogger(__name__)
99

10-
class TTCCalculation(Enum):
10+
class ProbCalculationMethod(Enum):
1111
SAMPLE = 1
1212
EXPECTED = 2
1313

14-
def sample_ttc(probs_dict):
15-
"""Calculate the sampled value from a ttc distribution function
14+
def sample_prob(probs_dict):
15+
"""Calculate the sampled value from a probability distribution function
1616
Arguments:
1717
probs_dict - a dictionary containing the probability distribution
18-
function that is part of a TTC
18+
function
1919
2020
Return:
2121
The float value obtained from calculating the sampled value corresponding
2222
to the function provided.
2323
"""
2424
if probs_dict['type'] != 'function':
25-
raise ValueError('Sample TTC function requires a function '
25+
raise ValueError('Sample probability method requires a function '
2626
f'probability distribution, but got "{probs_dict["type"]}"')
2727

2828
match(probs_dict['name']):
@@ -67,19 +67,20 @@ def sample_ttc(probs_dict):
6767
f'function encountered {probs_dict["name"]}!')
6868

6969

70-
def expected_ttc(probs_dict):
71-
"""Calculate the expected value from a ttc distribution function
70+
def expected_prob(probs_dict):
71+
"""Calculate the expected value from a probability distribution function
7272
Arguments:
7373
probs_dict - a dictionary containing the probability distribution
74-
function that is part of a TTC
74+
function
7575
7676
Return:
7777
The float value obtained from calculating the expected value corresponding
7878
to the function provided.
7979
"""
8080
if probs_dict['type'] != 'function':
81-
raise ValueError('Expected value TTC function requires a function '
82-
f'probability distribution, but got "{probs_dict["type"]}"')
81+
raise ValueError('Expected value probability method requires a '
82+
'function probability distribution, but got '
83+
f'"{probs_dict["type"]}"')
8384

8485
match(probs_dict['name']):
8586
case 'Bernoulli':
@@ -122,16 +123,16 @@ def expected_ttc(probs_dict):
122123
f'function encountered {probs_dict["name"]}!')
123124

124125

125-
def calculate_ttc(probs_dict: dict, method: TTCCalculation) -> float:
126-
"""Calculate the value from a ttc distribution
126+
def calculate_prob(probs_dict: dict, method: ProbCalculationMethod) -> float:
127+
"""Calculate the value from a probability distribution
127128
Arguments:
128129
probs_dict - a dictionary containing the probability distribution
129-
corresponding to the TTC
130-
method - the method to use in calculating the TTC
130+
function
131+
method - the method to use in calculating the probability
131132
values(currently supporting sampled or expected values)
132133
133134
Return:
134-
The float value obtained from calculating the TTC probability distribution.
135+
The float value obtained from calculating the probability distribution.
135136
136137
TTC Distributions in MAL:
137138
https://mal-lang.org/mal-langspec/apidocs/org.mal_lang.langspec/org/mal_lang/langspec/ttc/TtcDistribution.html
@@ -142,8 +143,8 @@ def calculate_ttc(probs_dict: dict, method: TTCCalculation) -> float:
142143
match(probs_dict['type']):
143144
case 'addition' | 'subtraction' | 'multiplication' | \
144145
'division' | 'exponentiation':
145-
lv = calculate_ttc(probs_dict['lhs'], method)
146-
rv = calculate_ttc(probs_dict['rhs'], method)
146+
lv = calculate_prob(probs_dict['lhs'], method)
147+
rv = calculate_prob(probs_dict['rhs'], method)
147148
match(probs_dict['type']):
148149
case 'addition':
149150
return lv + rv
@@ -161,12 +162,12 @@ def calculate_ttc(probs_dict: dict, method: TTCCalculation) -> float:
161162

162163
case 'function':
163164
match(method):
164-
case TTCCalculation.SAMPLE:
165-
return sample_ttc(probs_dict)
166-
case TTCCalculation.EXPECTED:
167-
return expected_ttc(probs_dict)
165+
case ProbCalculationMethod.SAMPLE:
166+
return sample_prob(probs_dict)
167+
case ProbCalculationMethod.EXPECTED:
168+
return expected_prob(probs_dict)
168169
case _:
169-
raise ValueError('Unknown TTC Calculation method '
170+
raise ValueError('Unknown Probability Calculation method '
170171
f'encountered {method}!')
171172

172173
case _:

tests/test_probs_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from maltoolbox.model import Model
66
from maltoolbox.attackgraph.attackgraph import AttackGraph
7-
from maltoolbox.probs_utils import calculate_ttc, TTCCalculation
7+
from maltoolbox.probs_utils import calculate_prob, ProbCalculationMethod
88

99
def test_probs_utils(model: Model):
1010
"""Test TTC calculation for nodes"""
@@ -24,8 +24,8 @@ def test_probs_utils(model: Model):
2424

2525
for node in attack_graph.nodes.values():
2626
#TODO: Actually check some of the results
27-
calculate_ttc(node.ttc, TTCCalculation.SAMPLE)
27+
calculate_prob(node.ttc, ProbCalculationMethod.SAMPLE)
2828

2929
for node in attack_graph.nodes.values():
3030
#TODO: Actually check some of the results
31-
calculate_ttc(node.ttc, TTCCalculation.EXPECTED)
31+
calculate_prob(node.ttc, ProbCalculationMethod.EXPECTED)

0 commit comments

Comments
 (0)