From 576b8d0685989427e52c35312e4ca77859dc8c72 Mon Sep 17 00:00:00 2001 From: Nikolaos Kakouros Date: Sat, 1 Feb 2025 04:37:55 +0200 Subject: [PATCH] Adds ruff lint workflow, config and autofixes issues --- .github/workflows/lint.yml | 29 + docs/conf.py | 9 +- maltoolbox/__init__.py | 31 +- maltoolbox/__main__.py | 42 +- maltoolbox/attackgraph/__init__.py | 9 +- maltoolbox/attackgraph/analyzers/apriori.py | 123 +- maltoolbox/attackgraph/attacker.py | 64 +- maltoolbox/attackgraph/attackgraph.py | 563 +- maltoolbox/attackgraph/node.py | 92 +- maltoolbox/attackgraph/query.py | 105 +- maltoolbox/exceptions.py | 23 +- maltoolbox/file_utils.py | 30 +- maltoolbox/ingestors/neo4j.py | 183 +- maltoolbox/language/__init__.py | 24 +- maltoolbox/language/classes_factory.py | 258 +- maltoolbox/language/compiler/__init__.py | 10 +- maltoolbox/language/compiler/mal_lexer.py | 2863 ++++++++++- maltoolbox/language/compiler/mal_parser.py | 4545 +++++++++++++---- maltoolbox/language/compiler/mal_visitor.py | 248 +- maltoolbox/language/languagegraph.py | 1289 ++--- maltoolbox/model.py | 528 +- maltoolbox/translators/securicad.py | 104 +- maltoolbox/translators/updater.py | 97 +- maltoolbox/wrappers.py | 29 +- pyproject.toml | 163 +- tests/attackgraph/test_analyzer.py | 84 +- tests/attackgraph/test_attacker.py | 69 +- tests/attackgraph/test_attackgraph.py | 773 +-- tests/attackgraph/test_node.py | 65 +- tests/attackgraph/test_query.py | 45 +- tests/conftest.py | 43 +- tests/language/test_classes_factory.py | 20 +- tests/language/test_languagegraph.py | 39 +- tests/test_model.py | 495 +- tests/test_wrappers.py | 8 +- .../translators/test_securicad_translator.py | 9 +- 36 files changed, 9152 insertions(+), 3959 deletions(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..a89d6246 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,29 @@ +name: Lint + +on: + push: + branches: + - main + pull_request: ~ + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + + - name: Install Ruff + run: | + python -m pip install --upgrade pip + pip install ruff + + - name: Check Code with Ruff + run: | + ruff check --statistics + ruff format --check diff --git a/docs/conf.py b/docs/conf.py index 53c93c71..0ddb9f6e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,13 +19,16 @@ project = 'MAL Toolbox' copyright = '2024, Andrei Buhaiu, Giuseppe Nebbione, Nikolaos Kakouros, Jakob Nyberg, Joakim Loxdal' -author = 'Andrei Buhaiu, Giuseppe Nebbione, Nikolaos Kakouros, Jakob Nyberg, Joakim Loxdal' +author = ( + 'Andrei Buhaiu, Giuseppe Nebbione, Nikolaos Kakouros, Jakob Nyberg, Joakim Loxdal' +) # -- General configuration --------------------------------------------------- import os import sys + sys.path.insert(0, os.path.abspath('../')) # Source code dir relative to this file # Add any Sphinx extension module names here, as strings. They can be @@ -53,9 +56,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' # pip install sphinx_rtd_theme +html_theme = 'sphinx_rtd_theme' # pip install sphinx_rtd_theme # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ['_static'] \ No newline at end of file +# html_static_path = ['_static'] diff --git a/maltoolbox/__init__.py b/maltoolbox/__init__.py index 8d06fd5b..66d46921 100644 --- a/maltoolbox/__init__.py +++ b/maltoolbox/__init__.py @@ -1,4 +1,3 @@ -# -*- encoding: utf-8 -*- # MAL Toolbox v0.2.0 # Copyright 2024, Andrei Buhaiu. # @@ -16,41 +15,40 @@ # -""" -MAL-Toolbox Framework -""" +"""MAL-Toolbox Framework.""" __title__ = 'maltoolbox' __version__ = '0.2.0' -__authors__ = ['Andrei Buhaiu', +__authors__ = [ + 'Andrei Buhaiu', 'Giuseppe Nebbione', 'Nikolaos Kakouros', 'Jakob Nyberg', - 'Joakim Loxdal'] + 'Joakim Loxdal', +] __license__ = 'Apache 2.0' __docformat__ = 'restructuredtext en' __all__ = () -import os import configparser import logging +import os ERROR_INCORRECT_CONFIG = 1 -CONFIGFILE = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "default.conf" -) +CONFIGFILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'default.conf') config = configparser.ConfigParser() config.read(CONFIGFILE) if 'logging' not in config: - raise ValueError('Config file is missing essential information, cannot proceed.') + msg = 'Config file is missing essential information, cannot proceed.' + raise ValueError(msg) if 'log_file' not in config['logging']: - raise ValueError('Config file is missing a log_file location, cannot proceed.') + msg = 'Config file is missing a log_file location, cannot proceed.' + raise ValueError(msg) log_configs = { 'log_file': config['logging']['log_file'], @@ -60,9 +58,11 @@ 'langspec_file': config['logging']['langspec_file'], } -os.makedirs(os.path.dirname(log_configs['log_file']), exist_ok = True) +os.makedirs(os.path.dirname(log_configs['log_file']), exist_ok=True) -formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s', datefmt='%m-%d %H:%M') +formatter = logging.Formatter( + '%(asctime)s %(name)-12s %(levelname)-8s %(message)s', datefmt='%m-%d %H:%M' +) file_handler = logging.FileHandler(log_configs['log_file'], mode='w') file_handler.setFormatter(formatter) @@ -78,7 +78,6 @@ if 'neo4j' in config: for term in ['uri', 'username', 'password', 'dbname']: if term not in config['neo4j']: - msg = ( 'Config file is missing essential Neo4J ' f'information: {term}, cannot proceed.' diff --git a/maltoolbox/__main__.py b/maltoolbox/__main__.py index 385e10bd..6a507446 100644 --- a/maltoolbox/__main__.py +++ b/maltoolbox/__main__.py @@ -1,5 +1,4 @@ -""" -Command-line interface for MAL toolbox operations +"""Command-line interface for MAL toolbox operations. Usage: maltoolbox attack-graph generate [options] @@ -20,58 +19,61 @@ - If --neo4j is used, the Neo4j instance should be running. The connection parameters required for this app to reach the Neo4j instance should be defined in the default.conf file. + """ -import logging import json +import logging + import docopt from maltoolbox.wrappers import create_attack_graph + from . import log_configs, neo4j_configs -from .language.compiler import MalCompiler from .ingestors import neo4j +from .language.compiler import MalCompiler logger = logging.getLogger(__name__) -def generate_attack_graph( - model_file: str, - lang_file: str, - send_to_neo4j: bool - ) -> None: - """Create an attack graph and optionally send to neo4j - + +def generate_attack_graph(model_file: str, lang_file: str, send_to_neo4j: bool) -> None: + """Create an attack graph and optionally send to neo4j. + Args: model_file - path to the model file lang_file - path to the language file send_to_neo4j - whether to ingest into neo4j or not + """ attack_graph = create_attack_graph(lang_file, model_file) if log_configs['attackgraph_file']: - attack_graph.save_to_file( - log_configs['attackgraph_file'] - ) + attack_graph.save_to_file(log_configs['attackgraph_file']) if send_to_neo4j: logger.debug('Ingest model graph into Neo4J database.') - neo4j.ingest_model(attack_graph.model, + neo4j.ingest_model( + attack_graph.model, neo4j_configs['uri'], neo4j_configs['username'], neo4j_configs['password'], neo4j_configs['dbname'], - delete=True) + delete=True, + ) logger.debug('Ingest attack graph into Neo4J database.') - neo4j.ingest_attack_graph(attack_graph, + neo4j.ingest_attack_graph( + attack_graph, neo4j_configs['uri'], neo4j_configs['username'], neo4j_configs['password'], neo4j_configs['dbname'], - delete=False) + delete=False, + ) def compile(lang_file: str, output_file: str) -> None: - """Compile language and dump into output file""" + """Compile language and dump into output file.""" compiler = MalCompiler() - with open(output_file, "w") as f: + with open(output_file, 'w', encoding='utf-8') as f: json.dump(compiler.compile(lang_file), f, indent=2) diff --git a/maltoolbox/attackgraph/__init__.py b/maltoolbox/attackgraph/__init__.py index a6b0e7c0..54447b60 100644 --- a/maltoolbox/attackgraph/__init__.py +++ b/maltoolbox/attackgraph/__init__.py @@ -1,8 +1,7 @@ -""" -Contains tools used to generate attack graphs from MAL instance +"""Contains tools used to generate attack graphs from MAL instance models and analyze attack graphs. """ -from .attacker import Attacker -from .attackgraph import AttackGraph -from .node import AttackGraphNode +from .attacker import Attacker as Attacker +from .attackgraph import AttackGraph as AttackGraph +from .node import AttackGraphNode as AttackGraphNode diff --git a/maltoolbox/attackgraph/analyzers/apriori.py b/maltoolbox/attackgraph/analyzers/apriori.py index f34dc8dd..6c134414 100644 --- a/maltoolbox/attackgraph/analyzers/apriori.py +++ b/maltoolbox/attackgraph/analyzers/apriori.py @@ -1,5 +1,4 @@ -""" -MAL-Toolbox Attack Graph Apriori Analyzer Submodule +"""MAL-Toolbox Attack Graph Apriori Analyzer Submodule. This submodule contains analyzers that are relevant before attackers are even connected to the attack graph. @@ -12,23 +11,28 @@ """ from __future__ import annotations -from typing import Optional + import logging +from typing import TYPE_CHECKING -from ..attackgraph import AttackGraph -from ..node import AttackGraphNode +if TYPE_CHECKING: + from maltoolbox.attackgraph.attackgraph import AttackGraph + from maltoolbox.attackgraph.node import AttackGraphNode logger = logging.getLogger(__name__) + def propagate_viability_from_node(node: AttackGraphNode) -> None: - """ - Arguments: + """Arguments: node - the attack graph node from which to propagate the viable - status + status. + """ logger.debug( 'Propagate viability from "%s"(%d) with viability status %s.', - node.full_name, node.id, node.is_viable + node.full_name, + node.id, + node.is_viable, ) for child in node.children: original_value = child.is_viable @@ -44,18 +48,20 @@ def propagate_viability_from_node(node: AttackGraphNode) -> None: def propagate_necessity_from_node(node: AttackGraphNode) -> None: - """ - Arguments: + """Arguments: node - the attack graph node from which to propagate the necessary - status + status. + """ logger.debug( 'Propagate necessity from "%s"(%d) with necessity status %s.', - node.full_name, node.id, node.is_necessary + node.full_name, + node.id, + node.is_necessary, ) if node.ttc and 'name' in node.ttc: - if node.ttc['name'] not in ['Enabled', 'Disabled', 'Instant']: + if node.ttc['name'] not in {'Enabled', 'Disabled', 'Instant'}: # Do not propagate unnecessary state from nodes that have a TTC # probability distribution associated with them. # TODO: Evaluate this more carefully, how do we want to have TTCs @@ -80,23 +86,25 @@ def propagate_necessity_from_node(node: AttackGraphNode) -> None: def evaluate_viability(node: AttackGraphNode) -> None: - """ - Arguments: + """Arguments: graph - the node to evaluate viability for. + """ - match (node.type): + match node.type: case 'exist': - assert isinstance(node.existence_status, bool), \ + assert isinstance(node.existence_status, bool), ( f'Existence status not defined for {node.full_name}.' + ) node.is_viable = node.existence_status case 'notExist': - assert isinstance(node.existence_status, bool), \ + assert isinstance(node.existence_status, bool), ( f'Existence status not defined for {node.full_name}.' + ) node.is_viable = not node.existence_status case 'defense': - assert node.defense_status is not None and \ - 0.0 <= node.defense_status <= 1.0, \ - f'{node.full_name} defense status invalid: {node.defense_status}.' + assert ( + node.defense_status is not None and 0.0 <= node.defense_status <= 1.0 + ), f'{node.full_name} defense status invalid: {node.defense_status}.' node.is_viable = node.defense_status != 1.0 case 'or': node.is_viable = False @@ -107,30 +115,34 @@ def evaluate_viability(node: AttackGraphNode) -> None: for parent in node.parents: node.is_viable = node.is_viable and parent.is_viable case _: - msg = ('Evaluate viability was provided node "%s"(%d) which ' - 'is of unknown type "%s"') + msg = ( + 'Evaluate viability was provided node "%s"(%d) which ' + 'is of unknown type "%s"' + ) logger.error(msg, node.full_name, node.id, node.type) raise ValueError(msg % (node.full_name, node.id, node.type)) def evaluate_necessity(node: AttackGraphNode) -> None: - """ - Arguments: + """Arguments: graph - the node to evaluate necessity for. + """ - match (node.type): + match node.type: case 'exist': - assert isinstance(node.existence_status, bool), \ + assert isinstance(node.existence_status, bool), ( f'Existence status not defined for {node.full_name}.' + ) node.is_necessary = not node.existence_status case 'notExist': - assert isinstance(node.existence_status, bool), \ + assert isinstance(node.existence_status, bool), ( f'Existence status not defined for {node.full_name}.' + ) node.is_necessary = bool(node.existence_status) case 'defense': - assert node.defense_status is not None and \ - 0.0 <= node.defense_status <= 1.0, \ - f'{node.full_name} defense status invalid: {node.defense_status}.' + assert ( + node.defense_status is not None and 0.0 <= node.defense_status <= 1.0 + ), f'{node.full_name} defense status invalid: {node.defense_status}.' node.is_necessary = node.defense_status != 0.0 case 'or': node.is_necessary = True @@ -141,29 +153,31 @@ def evaluate_necessity(node: AttackGraphNode) -> None: for parent in node.parents: node.is_necessary = node.is_necessary or parent.is_necessary case _: - msg = ('Evaluate necessity was provided node "%s"(%d) which ' - 'is of unknown type "%s"') + msg = ( + 'Evaluate necessity was provided node "%s"(%d) which ' + 'is of unknown type "%s"' + ) logger.error(msg, node.full_name, node.id, node.type) raise ValueError(msg % (node.full_name, node.id, node.type)) def evaluate_viability_and_necessity(node: AttackGraphNode) -> None: - """ - Arguments: + """Arguments: graph - the node to evaluate viability and necessity for. + """ evaluate_viability(node) evaluate_necessity(node) def calculate_viability_and_necessity(graph: AttackGraph) -> None: - """ - Arguments: + """Arguments: graph - the attack graph for which we wish to determine the viability and necessity statuses for the nodes. + """ for node in graph.nodes: - if node.type in ['exist', 'notExist', 'defense']: + if node.type in {'exist', 'notExist', 'defense'}: evaluate_viability_and_necessity(node) if not node.is_viable: propagate_viability_from_node(node) @@ -172,23 +186,23 @@ def calculate_viability_and_necessity(graph: AttackGraph) -> None: def prune_unviable_and_unnecessary_nodes(graph: AttackGraph) -> None: - """ - Arguments: + """Arguments: graph - the attack graph for which we wish to remove the the nodes which are not viable or necessary. + """ logger.debug('Prune unviable and unnecessary nodes from the attack graph.') for node in graph.nodes: - if (node.type == 'or' or node.type == 'and') and \ - (not node.is_viable or not node.is_necessary): + if (node.type in {'or', 'and'}) and ( + not node.is_viable or not node.is_necessary + ): graph.remove_node(node) def propagate_viability_from_unviable_node( - unviable_node: AttackGraphNode, - ) -> list[AttackGraphNode]: - """ - Update viability of nodes affected by newly enabled defense + unviable_node: AttackGraphNode, +) -> list[AttackGraphNode]: + """Update viability of nodes affected by newly enabled defense `unviable_node` in the graph and return any attack steps that are no longer viable because of it. @@ -201,22 +215,20 @@ def propagate_viability_from_unviable_node( attack_steps_made_unviable - list of the attack steps that have been made unviable by a defense enabled in the current step. Builds up recursively. - """ + """ attack_steps_made_unviable = [] logger.debug( - 'Update viability for node "%s"(%d)', - unviable_node.full_name, - unviable_node.id + 'Update viability for node "%s"(%d)', unviable_node.full_name, unviable_node.id ) assert not unviable_node.is_viable, ( - "propagate_viability_from_unviable_node should not be called" - f" on viable node {unviable_node.full_name}" + 'propagate_viability_from_unviable_node should not be called' + f' on viable node {unviable_node.full_name}' ) - if unviable_node.type in ('and', 'or'): + if unviable_node.type in {'and', 'or'}: attack_steps_made_unviable.append(unviable_node) for child in unviable_node.children: @@ -229,7 +241,6 @@ def propagate_viability_from_unviable_node( child.is_viable = False if child.is_viable != original_value: - attack_steps_made_unviable += \ - propagate_viability_from_unviable_node(child) + attack_steps_made_unviable += propagate_viability_from_unviable_node(child) return attack_steps_made_unviable diff --git a/maltoolbox/attackgraph/attacker.py b/maltoolbox/attackgraph/attacker.py index 1aa1fe80..e1913dc4 100644 --- a/maltoolbox/attackgraph/attacker.py +++ b/maltoolbox/attackgraph/attacker.py @@ -1,41 +1,39 @@ -""" -MAL-Toolbox Attack Graph Attacker Class -""" +"""MAL-Toolbox Attack Graph Attacker Class.""" from __future__ import annotations -from dataclasses import dataclass, field + import copy import logging - -from typing import Optional +from dataclasses import dataclass, field from typing import TYPE_CHECKING + if TYPE_CHECKING: from .attackgraph import AttackGraphNode logger = logging.getLogger(__name__) + @dataclass class Attacker: name: str entry_points: list[AttackGraphNode] = field(default_factory=list) - reached_attack_steps: list[AttackGraphNode] = \ - field(default_factory=list) - id: Optional[int] = None + reached_attack_steps: list[AttackGraphNode] = field(default_factory=list) + id: int | None = None def to_dict(self) -> dict: attacker_dict: dict = { 'id': self.id, 'name': self.name, 'entry_points': {}, - 'reached_attack_steps': {} + 'reached_attack_steps': {}, } for entry_point in self.entry_points: - attacker_dict['entry_points'][entry_point.id] = \ - entry_point.full_name + attacker_dict['entry_points'][entry_point.id] = entry_point.full_name for attack_step in self.reached_attack_steps: - attacker_dict['reached_attack_steps'][attack_step.id] = \ + attacker_dict['reached_attack_steps'][attack_step.id] = ( attack_step.full_name + ) return attacker_dict @@ -43,50 +41,47 @@ def __repr__(self) -> str: return str(self.to_dict()) def __deepcopy__(self, memo) -> Attacker: - """Deep copy an Attacker""" - + """Deep copy an Attacker.""" # Check if the object is already in the memo dictionary if id(self) in memo: return memo[id(self)] copied_attacker = Attacker( - id = self.id, - name = self.name, + id=self.id, + name=self.name, ) # Remember that self was already copied memo[id(self)] = copied_attacker - copied_attacker.entry_points = copy.deepcopy( - self.entry_points, memo = memo) + copied_attacker.entry_points = copy.deepcopy(self.entry_points, memo=memo) copied_attacker.reached_attack_steps = copy.deepcopy( - self.reached_attack_steps, memo = memo) + self.reached_attack_steps, memo=memo + ) return copied_attacker def compromise(self, node: AttackGraphNode) -> None: - """ - Have the attacke compromise the node given as a parameter. + """Have the attacke compromise the node given as a parameter. Arguments: node - the node that the attacker will compromise - """ + """ logger.debug( 'Attacker "%s"(%d) is compromising node "%s"(%d).', self.name, self.id, node.full_name, - node.id + node.id, ) if node.is_compromised_by(self): logger.info( - 'Attacker "%s"(%d) already compromised node "%s"(%d). ' - 'Do nothing.', + 'Attacker "%s"(%d) already compromised node "%s"(%d). Do nothing.', self.name, self.id, node.full_name, - node.id + node.id, ) return @@ -94,30 +89,27 @@ def compromise(self, node: AttackGraphNode) -> None: self.reached_attack_steps.append(node) def undo_compromise(self, node: AttackGraphNode) -> None: - """ - Remove the attacker from the list of attackers that have compromised + """Remove the attacker from the list of attackers that have compromised the node given as a parameter. Arguments: node - the node that we wish to remove this attacker from. - """ + """ logger.debug( - 'Removing attacker "%s"(%d) from compromised_by ' - 'list of node "%s"(%d).', + 'Removing attacker "%s"(%d) from compromised_by list of node "%s"(%d).', self.name, self.id, node.full_name, - node.id + node.id, ) if not node.is_compromised_by(self): logger.info( - 'Attacker "%s"(%d) had not compromised node "%s"(%d).' - ' Do nothing.', + 'Attacker "%s"(%d) had not compromised node "%s"(%d). Do nothing.', self.name, self.id, node.full_name, - node.id + node.id, ) return diff --git a/maltoolbox/attackgraph/attackgraph.py b/maltoolbox/attackgraph/attackgraph.py index 0da48de3..0e7fedd0 100644 --- a/maltoolbox/attackgraph/attackgraph.py +++ b/maltoolbox/attackgraph/attackgraph.py @@ -1,37 +1,43 @@ -""" -MAL-Toolbox Attack Graph Module -""" +"""MAL-Toolbox Attack Graph Module.""" + from __future__ import annotations + import copy -import logging import json - -from itertools import chain +import logging from typing import TYPE_CHECKING -from .node import AttackGraphNode -from .attacker import Attacker -from ..exceptions import AttackGraphStepExpressionError, AttackGraphException -from ..exceptions import LanguageGraphException -from ..model import Model -from ..language import (LanguageGraph, ExpressionsChain, - disaggregate_attack_step_full_name) -from ..file_utils import ( +from maltoolbox.exceptions import ( + AttackGraphException, + AttackGraphStepExpressionError, + LanguageGraphException, +) +from maltoolbox.file_utils import ( load_dict_from_json_file, load_dict_from_yaml_file, - save_dict_to_file + save_dict_to_file, +) +from maltoolbox.language import ( + ExpressionsChain, + LanguageGraph, + disaggregate_attack_step_full_name, ) +from .attacker import Attacker +from .node import AttackGraphNode if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any + + from maltoolbox.model import Model logger = logging.getLogger(__name__) -class AttackGraph(): - """Graph representation of attack steps""" - def __init__(self, lang_graph, model: Optional[Model] = None): +class AttackGraph: + """Graph representation of attack steps.""" + + def __init__(self, lang_graph, model: Model | None = None) -> None: self.nodes: list[AttackGraphNode] = [] self.attackers: list[Attacker] = [] # Dictionaries used in optimization to get nodes and attackers by id @@ -51,16 +57,16 @@ def __repr__(self) -> str: return f'AttackGraph({len(self.nodes)} nodes)' def _to_dict(self) -> dict: - """Convert AttackGraph to dict""" + """Convert AttackGraph to dict.""" serialized_attack_steps = {} serialized_attackers = {} for ag_node in self.nodes: - serialized_attack_steps[ag_node.full_name] =\ - ag_node.to_dict() + serialized_attack_steps[ag_node.full_name] = ag_node.to_dict() for attacker in self.attackers: serialized_attackers[attacker.name] = attacker.to_dict() - logger.debug('Serialized %d attack steps and %d attackers.' % - (len(self.nodes), len(self.attackers)) + logger.debug( + 'Serialized %d attack steps and %d attackers.' + % (len(self.nodes), len(self.attackers)) ) return { 'attack_steps': serialized_attack_steps, @@ -68,7 +74,6 @@ def _to_dict(self) -> dict: } def __deepcopy__(self, memo): - # Check if the object is already in the memo dictionary if id(self) in memo: return memo[id(self)] @@ -96,16 +101,14 @@ def __deepcopy__(self, memo): # Re-link attacker references for node in self.nodes: if node.compromised_by: - memo[id(node)].compromised_by = copy.deepcopy( - node.compromised_by, memo) + memo[id(node)].compromised_by = copy.deepcopy(node.compromised_by, memo) # Copy lookup dicts - copied_attackgraph._id_to_attacker = \ - copy.deepcopy(self._id_to_attacker, memo) - copied_attackgraph._id_to_node = \ - copy.deepcopy(self._id_to_node, memo) - copied_attackgraph._full_name_to_node = \ - copy.deepcopy(self._full_name_to_node, memo) + copied_attackgraph._id_to_attacker = copy.deepcopy(self._id_to_attacker, memo) + copied_attackgraph._id_to_node = copy.deepcopy(self._id_to_node, memo) + copied_attackgraph._full_name_to_node = copy.deepcopy( + self._full_name_to_node, memo + ) # Copy counters copied_attackgraph.next_node_id = self.next_node_id @@ -114,23 +117,22 @@ def __deepcopy__(self, memo): return copied_attackgraph def save_to_file(self, filename: str) -> None: - """Save to json/yml depending on extension""" + """Save to json/yml depending on extension.""" logger.debug('Save attack graph to file "%s".', filename) return save_dict_to_file(filename, self._to_dict()) @classmethod def _from_dict( - cls, - serialized_object: dict, - lang_graph: LanguageGraph, - model: Optional[Model]=None - ) -> AttackGraph: + cls, + serialized_object: dict, + lang_graph: LanguageGraph, + model: Model | None = None, + ) -> AttackGraph: """Create AttackGraph from dict Args: serialized_object - AttackGraph in dict format - model - Optional Model to add connections to + model - Optional Model to add connections to. """ - attack_graph = AttackGraph(lang_graph) attack_graph.model = model serialized_attack_steps = serialized_object['attack_steps'] @@ -138,28 +140,30 @@ def _from_dict( # Create all of the nodes in the imported attack graph. for node_dict in serialized_attack_steps.values(): - # Recreate asset links if model is available. node_asset = None if model and 'asset' in node_dict: node_asset = model.get_asset_by_name(node_dict['asset']) if node_asset is None: - msg = ('Failed to find asset with id %s' - 'when loading from attack graph dict') - logger.error(msg, node_dict["asset"]) - raise LookupError(msg % node_dict["asset"]) - - lg_asset_name, lg_attack_step_name = \ - disaggregate_attack_step_full_name( - node_dict['lang_graph_attack_step']) - lg_attack_step = lang_graph.assets[lg_asset_name].\ - attack_steps[lg_attack_step_name] + msg = ( + 'Failed to find asset with id %s' + 'when loading from attack graph dict' + ) + logger.error(msg, node_dict['asset']) + raise LookupError(msg % node_dict['asset']) + + lg_asset_name, lg_attack_step_name = disaggregate_attack_step_full_name( + node_dict['lang_graph_attack_step'] + ) + lg_attack_step = lang_graph.assets[lg_asset_name].attack_steps[ + lg_attack_step_name + ] ag_node = AttackGraphNode( type=node_dict['type'], - lang_graph_attack_step = lg_attack_step, + lang_graph_attack_step=lg_attack_step, name=node_dict['name'], ttc=node_dict['ttc'], - asset=node_asset + asset=node_asset, ) if node_asset: @@ -171,16 +175,25 @@ def _from_dict( else: node_asset.attack_step_nodes = [ag_node] - ag_node.defense_status = float(node_dict['defense_status']) if \ - 'defense_status' in node_dict else None - ag_node.existence_status = node_dict['existence_status'] \ - == 'True' if 'existence_status' in node_dict else None - ag_node.is_viable = node_dict['is_viable'] == 'True' if \ - 'is_viable' in node_dict else True - ag_node.is_necessary = node_dict['is_necessary'] == 'True' if \ - 'is_necessary' in node_dict else True - ag_node.tags = set(node_dict['tags']) if \ - 'tags' in node_dict else set() + ag_node.defense_status = ( + float(node_dict['defense_status']) + if 'defense_status' in node_dict + else None + ) + ag_node.existence_status = ( + node_dict['existence_status'] == 'True' + if 'existence_status' in node_dict + else None + ) + ag_node.is_viable = ( + node_dict['is_viable'] == 'True' if 'is_viable' in node_dict else True + ) + ag_node.is_necessary = ( + node_dict['is_necessary'] == 'True' + if 'is_necessary' in node_dict + else True + ) + ag_node.tags = set(node_dict['tags']) if 'tags' in node_dict else set() ag_node.extras = node_dict.get('extras', {}) # Add AttackGraphNode to AttackGraph @@ -190,88 +203,87 @@ def _from_dict( for node_dict in serialized_attack_steps.values(): _ag_node = attack_graph.get_node_by_id(node_dict['id']) if not isinstance(_ag_node, AttackGraphNode): - msg = ('Failed to find node with id %s when loading' - ' attack graph from dict') - logger.error(msg, node_dict["id"]) - raise LookupError(msg % node_dict["id"]) - else: - for child_id in node_dict['children']: - child = attack_graph.get_node_by_id(int(child_id)) - if child is None: - msg = ('Failed to find child node with id %s' - ' when loading from attack graph from dict') - logger.error(msg, child_id) - raise LookupError(msg % child_id) - _ag_node.children.append(child) - - for parent_id in node_dict['parents']: - parent = attack_graph.get_node_by_id(int(parent_id)) - if parent is None: - msg = ('Failed to find parent node with id %s ' - 'when loading from attack graph from dict') - logger.error(msg, parent_id) - raise LookupError(msg % parent_id) - _ag_node.parents.append(parent) + msg = ( + 'Failed to find node with id %s when loading attack graph from dict' + ) + logger.error(msg, node_dict['id']) + raise LookupError(msg % node_dict['id']) + for child_id in node_dict['children']: + child = attack_graph.get_node_by_id(int(child_id)) + if child is None: + msg = ( + 'Failed to find child node with id %s' + ' when loading from attack graph from dict' + ) + logger.error(msg, child_id) + raise LookupError(msg % child_id) + _ag_node.children.append(child) + + for parent_id in node_dict['parents']: + parent = attack_graph.get_node_by_id(int(parent_id)) + if parent is None: + msg = ( + 'Failed to find parent node with id %s ' + 'when loading from attack graph from dict' + ) + logger.error(msg, parent_id) + raise LookupError(msg % parent_id) + _ag_node.parents.append(parent) for attacker in serialized_attackers.values(): ag_attacker = Attacker( - name = attacker['name'], - entry_points = [], - reached_attack_steps = [] + name=attacker['name'], entry_points=[], reached_attack_steps=[] ) attack_graph.add_attacker( - attacker = ag_attacker, - attacker_id = int(attacker['id']), - entry_points = attacker['entry_points'].keys(), - reached_attack_steps = [ - int(node_id) # Convert to int since they can be strings - for node_id in attacker['reached_attack_steps'].keys() - ] + attacker=ag_attacker, + attacker_id=int(attacker['id']), + entry_points=attacker['entry_points'].keys(), + reached_attack_steps=[ + int(node_id) # Convert to int since they can be strings + for node_id in attacker['reached_attack_steps'] + ], ) return attack_graph @classmethod def load_from_file( - cls, - filename: str, - lang_graph: LanguageGraph, - model: Optional[Model] = None - ) -> AttackGraph: - """Create from json or yaml file depending on file extension""" + cls, filename: str, lang_graph: LanguageGraph, model: Model | None = None + ) -> AttackGraph: + """Create from json or yaml file depending on file extension.""" if model is not None: - logger.debug('Load attack graph from file "%s" with ' - 'model "%s".', filename, model.name) + logger.debug( + 'Load attack graph from file "%s" with model "%s".', + filename, + model.name, + ) else: - logger.debug('Load attack graph from file "%s" ' - 'without model.', filename) + logger.debug('Load attack graph from file "%s" without model.', filename) serialized_attack_graph = None if filename.endswith(('.yml', '.yaml')): serialized_attack_graph = load_dict_from_yaml_file(filename) elif filename.endswith('.json'): serialized_attack_graph = load_dict_from_json_file(filename) else: - raise ValueError('Unknown file extension, expected json/yml/yaml') - return cls._from_dict(serialized_attack_graph, - lang_graph, model = model) + msg = 'Unknown file extension, expected json/yml/yaml' + raise ValueError(msg) + return cls._from_dict(serialized_attack_graph, lang_graph, model=model) - def get_node_by_id(self, node_id: int) -> Optional[AttackGraphNode]: - """ - Return the attack node that matches the id provided. + def get_node_by_id(self, node_id: int) -> AttackGraphNode | None: + """Return the attack node that matches the id provided. Arguments: node_id - the id of the attack graph node we are looking for Return: The attack step node that matches the given id. - """ + """ logger.debug('Looking up node with id %s', node_id) return self._id_to_node.get(node_id) - def get_node_by_full_name(self, full_name: str) -> Optional[AttackGraphNode]: - """ - Return the attack node that matches the full name provided. + def get_node_by_full_name(self, full_name: str) -> AttackGraphNode | None: + """Return the attack node that matches the full name provided. Arguments: full_name - the full name of the attack graph node we are looking @@ -279,63 +291,55 @@ def get_node_by_full_name(self, full_name: str) -> Optional[AttackGraphNode]: Return: The attack step node that matches the given full name. - """ - logger.debug(f'Looking up node with full name "%s"', full_name) + """ + logger.debug('Looking up node with full name "%s"', full_name) return self._full_name_to_node.get(full_name) - def get_attacker_by_id(self, attacker_id: int) -> Optional[Attacker]: - """ - Return the attacker that matches the id provided. + def get_attacker_by_id(self, attacker_id: int) -> Attacker | None: + """Return the attacker that matches the id provided. Arguments: attacker_id - the id of the attacker we are looking for Return: The attacker that matches the given id. - """ + """ logger.debug(f'Looking up attacker with id {attacker_id}') return self._id_to_attacker.get(attacker_id) def attach_attackers(self) -> None: - """ - Create attackers and their entry point nodes and attach them to the + """Create attackers and their entry point nodes and attach them to the relevant attack step nodes and to the attackers. """ - if not self.model: - msg = "Can not attach attackers without a model" + msg = 'Can not attach attackers without a model' logger.error(msg) raise AttackGraphException(msg) - logger.info( - 'Attach attackers from "%s" model to the graph.', self.model.name - ) + logger.info('Attach attackers from "%s" model to the graph.', self.model.name) for attacker_info in self.model.attackers: - if not attacker_info.name: - msg = "Can not attach attacker without name" + msg = 'Can not attach attacker without name' logger.error(msg) raise AttackGraphException(msg) attacker = Attacker( - name = attacker_info.name, - entry_points = [], - reached_attack_steps = [] + name=attacker_info.name, entry_points=[], reached_attack_steps=[] ) self.add_attacker(attacker) - for (asset, attack_steps) in attacker_info.entry_points: + for asset, attack_steps in attacker_info.entry_points: for attack_step in attack_steps: full_name = asset.name + ':' + attack_step ag_node = self.get_node_by_full_name(full_name) if not ag_node: logger.warning( - 'Failed to find attacker entry point ' - '%s for %s.', - full_name, attacker.name + 'Failed to find attacker entry point %s for %s.', + full_name, + attacker.name, ) continue attacker.compromise(ag_node) @@ -343,13 +347,9 @@ def attach_attackers(self) -> None: attacker.entry_points = list(attacker.reached_attack_steps) def _follow_expr_chain( - self, - model: Model, - target_assets: set[Any], - expr_chain: Optional[ExpressionsChain] - ) -> set[Any]: - """ - Recursively follow a language graph expressions chain on an instance + self, model: Model, target_assets: set[Any], expr_chain: ExpressionsChain | None + ) -> set[Any]: + """Recursively follow a language graph expressions chain on an instance model. Arguments: @@ -362,8 +362,8 @@ def _follow_expr_chain( Return: A list of all of the target assets. - """ + """ if expr_chain is None: # There is no expressions chain link left to follow return the # current target assets @@ -373,31 +373,33 @@ def _follow_expr_chain( # Avoid running json.dumps when not in debug logger.debug( 'Following Expressions Chain:\n%s', - json.dumps(expr_chain.to_dict(), indent = 2) + json.dumps(expr_chain.to_dict(), indent=2), ) - match (expr_chain.type): + match expr_chain.type: case 'union' | 'intersection' | 'difference': # The set operators are used to combine the left hand and # right hand targets accordingly. if not expr_chain.left_link: - raise LanguageGraphException('"%s" step expression chain' - ' is missing the left link.' % expr_chain.type) + msg = ( + f'"{expr_chain.type}" step expression chain' + ' is missing the left link.' + ) + raise LanguageGraphException(msg) if not expr_chain.right_link: - raise LanguageGraphException('"%s" step expression chain' - ' is missing the right link.' % expr_chain.type) + msg = ( + f'"{expr_chain.type}" step expression chain' + ' is missing the right link.' + ) + raise LanguageGraphException(msg) lh_targets = self._follow_expr_chain( - model, - target_assets, - expr_chain.left_link + model, target_assets, expr_chain.left_link ) rh_targets = self._follow_expr_chain( - model, - target_assets, - expr_chain.right_link + model, target_assets, expr_chain.right_link ) - match (expr_chain.type): + match expr_chain.type: # Once the assets become hashable set operations should be # used instead. case 'union': @@ -415,8 +417,8 @@ def _follow_expr_chain( # Change the target assets from the current ones to the # associated assets given the specified field name. if not expr_chain.fieldname: - raise LanguageGraphException('"field" step expression ' - 'chain is missing fieldname.') + msg = '"field" step expression chain is missing fieldname.' + raise LanguageGraphException(msg) new_target_assets = set() new_target_assets.update( *( @@ -424,14 +426,14 @@ def _follow_expr_chain( asset, expr_chain.fieldname ) for asset in target_assets - ) + ) ) return new_target_assets case 'transitive': if not expr_chain.sub_link: - raise LanguageGraphException('"transitive" step ' - 'expression chain is missing sub link.') + msg = '"transitive" step expression chain is missing sub link.' + raise LanguageGraphException(msg) new_assets = target_assets @@ -447,110 +449,94 @@ def _follow_expr_chain( case 'subType': if not expr_chain.sub_link: - raise LanguageGraphException('"subType" step ' - 'expression chain is missing sub link.') + msg = '"subType" step expression chain is missing sub link.' + raise LanguageGraphException(msg) new_target_assets = set() new_target_assets.update( - self._follow_expr_chain( - model, target_assets, expr_chain.sub_link - ) + self._follow_expr_chain(model, target_assets, expr_chain.sub_link) ) selected_new_target_assets = set() for asset in new_target_assets: lang_graph_asset = self.lang_graph.assets[asset.type] if not lang_graph_asset: - raise LookupError( - f'Failed to find asset \"{asset.type}\" in the ' + msg = ( + f'Failed to find asset "{asset.type}" in the ' 'language graph.' ) + raise LookupError(msg) lang_graph_subtype_asset = expr_chain.subtype if not lang_graph_subtype_asset: - raise LookupError( - 'Failed to find asset "%s" in the ' - 'language graph.' % expr_chain.subtype + msg = ( + f'Failed to find asset "{expr_chain.subtype}" in the ' + 'language graph.' ) - if lang_graph_asset.is_subasset_of( - lang_graph_subtype_asset): + raise LookupError(msg) + if lang_graph_asset.is_subasset_of(lang_graph_subtype_asset): selected_new_target_assets.add(asset) return selected_new_target_assets case 'collect': if not expr_chain.left_link: - raise LanguageGraphException('"collect" step expression chain' - ' is missing the left link.') + msg = '"collect" step expression chain is missing the left link.' + raise LanguageGraphException(msg) if not expr_chain.right_link: - raise LanguageGraphException('"collect" step expression chain' - ' is missing the right link.') + msg = '"collect" step expression chain is missing the right link.' + raise LanguageGraphException(msg) lh_targets = self._follow_expr_chain( - model, - target_assets, - expr_chain.left_link + model, target_assets, expr_chain.left_link ) - rh_targets = self._follow_expr_chain( - model, - lh_targets, - expr_chain.right_link - ) - return rh_targets + return self._follow_expr_chain(model, lh_targets, expr_chain.right_link) case _: msg = 'Unknown attack expressions chain type: %s' - logger.error( - msg, - expr_chain.type - ) - raise AttackGraphStepExpressionError( - msg % expr_chain.type - ) + logger.error(msg, expr_chain.type) + raise AttackGraphStepExpressionError(msg % expr_chain.type) return None def _generate_graph(self) -> None: - """ - Generate the attack graph based on the original model instance and the + """Generate the attack graph based on the original model instance and the MAL language specification provided at initialization. """ - if not self.model: - msg = "Can not generate AttackGraph without model" + msg = 'Can not generate AttackGraph without model' logger.error(msg) raise AttackGraphException(msg) # First, generate all of the nodes of the attack graph. for asset in self.model.assets: - logger.debug( 'Generating attack steps for asset %s which is of class %s.', - asset.name, asset.type + asset.name, + asset.type, ) attack_step_nodes = [] lang_graph_asset = self.lang_graph.assets[asset.type] if lang_graph_asset is None: - raise LookupError( - f'Failed to find asset with name \"{asset.type}\" in ' + msg = ( + f'Failed to find asset with name "{asset.type}" in ' 'the language graph.' ) + raise LookupError(msg) for attack_step in lang_graph_asset.attack_steps.values(): - logger.debug( - 'Generating attack step node for %s.', attack_step.name - ) + logger.debug('Generating attack step node for %s.', attack_step.name) defense_status = None existence_status = None node_name = asset.name + ':' + attack_step.name - match (attack_step.type): + match attack_step.type: case 'defense': # Set the defense status for defenses defense_status = getattr(asset, attack_step.name) logger.debug( - 'Setting the defense status of \"%s\" to ' - '\"%s\".', - node_name, defense_status + 'Setting the defense status of "%s" to "%s".', + node_name, + defense_status, ) case 'exist' | 'notExist': @@ -559,10 +545,8 @@ def _generate_graph(self) -> None: existence_status = False for requirement in attack_step.requires: target_assets = self._follow_expr_chain( - self.model, - set([asset]), - requirement - ) + self.model, {asset}, requirement + ) # If the step expression resolution yielded # the target assets then the required assets # exist in the model. @@ -574,19 +558,19 @@ def _generate_graph(self) -> None: pass ag_node = AttackGraphNode( - type = attack_step.type, - lang_graph_attack_step = attack_step, - asset = asset, - name = attack_step.name, - ttc = attack_step.ttc, - children = [], - parents = [], - defense_status = defense_status, - existence_status = existence_status, - is_viable = True, - is_necessary = True, - tags = set(attack_step.tags), - compromised_by = [] + type=attack_step.type, + lang_graph_attack_step=attack_step, + asset=asset, + name=attack_step.name, + ttc=attack_step.ttc, + children=[], + parents=[], + defense_status=defense_status, + existence_status=existence_status, + is_viable=True, + is_necessary=True, + tags=set(attack_step.tags), + compromised_by=[], ) attack_step_nodes.append(ag_node) self.add_node(ag_node) @@ -597,60 +581,63 @@ def _generate_graph(self) -> None: logger.debug( 'Determining children for attack step "%s"(%d)', ag_node.full_name, - ag_node.id + ag_node.id, ) if not ag_node.asset: - raise AttackGraphException('Attack graph node is missing ' - 'asset link') + msg = 'Attack graph node is missing asset link' + raise AttackGraphException(msg) lang_graph_asset = self.lang_graph.assets[ag_node.asset.type] - lang_graph_attack_step = lang_graph_asset.attack_steps[\ - ag_node.name] + lang_graph_attack_step = lang_graph_asset.attack_steps[ag_node.name] while lang_graph_attack_step: for child in lang_graph_attack_step.children.values(): for target_attack_step, expr_chain in child: target_assets = self._follow_expr_chain( - self.model, - set([ag_node.asset]), - expr_chain + self.model, {ag_node.asset}, expr_chain ) for target_asset in target_assets: if target_asset is not None: - target_node_full_name = target_asset.name + \ - ':' + target_attack_step.name + target_node_full_name = ( + target_asset.name + ':' + target_attack_step.name + ) target_node = self.get_node_by_full_name( - target_node_full_name) + target_node_full_name + ) if target_node is None: - msg = ('Failed to find target node ' - '"%s" to link with for attack ' - 'step "%s"(%d)!') + msg = ( + 'Failed to find target node ' + '"%s" to link with for attack ' + 'step "%s"(%d)!' + ) logger.error( msg, target_node_full_name, ag_node.full_name, - ag_node.id + ag_node.id, ) raise AttackGraphStepExpressionError( - msg % ( + msg + % ( target_node_full_name, ag_node.full_name, - ag_node.id + ag_node.id, ) ) assert ag_node.id is not None assert target_node.id is not None - logger.debug('Linking attack step "%s"(%d) ' - 'to attack step "%s"(%d)' % - ( + logger.debug( + 'Linking attack step "%s"(%d) ' + 'to attack step "%s"(%d)' + % ( ag_node.full_name, ag_node.id, target_node.full_name, - target_node.id + target_node.id, ) ) ag_node.children.append(target_node) @@ -659,36 +646,31 @@ def _generate_graph(self) -> None: break lang_graph_attack_step = lang_graph_attack_step.inherits - def regenerate_graph(self) -> None: - """ - Regenerate the attack graph based on the original model instance and + """Regenerate the attack graph based on the original model instance and the MAL language specification provided at initialization. """ - self.nodes = [] self.attackers = [] self._generate_graph() - def add_node( - self, - node: AttackGraphNode, - node_id: Optional[int] = None - ) -> None: + def add_node(self, node: AttackGraphNode, node_id: int | None = None) -> None: """Add a node to the graph Arguments: node - the node to add node_id - the id to assign to this node, usually used when loading - an attack graph from a file + an attack graph from a file. """ if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug - logger.debug(f'Add node \"{node.full_name}\" ' - f'with id:{node_id}:\n' \ - + json.dumps(node.to_dict(), indent = 2)) + logger.debug( + f'Add node "{node.full_name}" ' + f'with id:{node_id}:\n' + json.dumps(node.to_dict(), indent=2) + ) if node.id in self._id_to_node: - raise ValueError(f'Node index {node_id} already in use.') + msg = f'Node index {node_id} already in use.' + raise ValueError(msg) node.id = node_id if node_id is not None else self.next_node_id self.next_node_id = max(node.id + 1, self.next_node_id) @@ -700,11 +682,11 @@ def add_node( def remove_node(self, node: AttackGraphNode) -> None: """Remove node from attack graph Arguments: - node - the node we wish to remove from the attack graph + node - the node we wish to remove from the attack graph. """ if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug - logger.debug(f'Remove node "%s"(%d).', node.full_name, node.id) + logger.debug('Remove node "%s"(%d).', node.full_name, node.id) for child in node.children: child.parents.remove(node) for parent in node.parents: @@ -712,17 +694,18 @@ def remove_node(self, node: AttackGraphNode) -> None: self.nodes.remove(node) if not isinstance(node.id, int): - raise ValueError(f'Invalid node id.') + msg = 'Invalid node id.' + raise ValueError(msg) del self._id_to_node[node.id] del self._full_name_to_node[node.full_name] def add_attacker( - self, - attacker: Attacker, - attacker_id: Optional[int] = None, - entry_points: list[int] = [], - reached_attack_steps: list[int] = [] - ): + self, + attacker: Attacker, + attacker_id: int | None = None, + entry_points: list[int] | None = None, + reached_attack_steps: list[int] | None = None, + ) -> None: """Add an attacker to the graph Arguments: attacker - the attacker to add @@ -732,24 +715,25 @@ def add_attacker( entry_points - list of attack step ids that serve as entry points for the attacker reached_attack_steps - list of ids of the attack steps that the - attacker has reached + attacker has reached. """ - + if reached_attack_steps is None: + reached_attack_steps = [] + if entry_points is None: + entry_points = [] if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug if attacker_id is not None: - logger.debug('Add attacker "%s" with id:%d.', - attacker.name, - attacker_id + logger.debug( + 'Add attacker "%s" with id:%d.', attacker.name, attacker_id ) else: - logger.debug('Add attacker "%s" without id.', - attacker.name - ) + logger.debug('Add attacker "%s" without id.', attacker.name) attacker.id = attacker_id or self.next_attacker_id if attacker.id in self._id_to_attacker: - raise ValueError(f'Attacker index {attacker_id} already in use.') + msg = f'Attacker index {attacker_id} already in use.' + raise ValueError(msg) self.next_attacker_id = max(attacker.id + 1, self.next_attacker_id) for node_id in reached_attack_steps: @@ -757,8 +741,7 @@ def add_attacker( if node: attacker.compromise(node) else: - msg = ("Could not find node with id %d" - "in reached attack steps.") + msg = 'Could not find node with id %din reached attack steps.' logger.error(msg, node_id) raise AttackGraphException(msg % node_id) for node_id in entry_points: @@ -766,26 +749,24 @@ def add_attacker( if node: attacker.entry_points.append(node) else: - msg = ("Could not find node with id %d" - "in attacker entrypoints.") + msg = 'Could not find node with id %din attacker entrypoints.' logger.error(msg, node_id) raise AttackGraphException(msg % node_id) self.attackers.append(attacker) self._id_to_attacker[attacker.id] = attacker - def remove_attacker(self, attacker: Attacker): + def remove_attacker(self, attacker: Attacker) -> None: """Remove attacker from attack graph Arguments: - attacker - the attacker we wish to remove from the attack graph + attacker - the attacker we wish to remove from the attack graph. """ if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug - logger.debug('Remove attacker "%s" with id:%d.', - attacker.name, - attacker.id) + logger.debug('Remove attacker "%s" with id:%d.', attacker.name, attacker.id) for node in attacker.reached_attack_steps: attacker.undo_compromise(node) self.attackers.remove(attacker) if not isinstance(attacker.id, int): - raise ValueError(f'Invalid attacker id.') + msg = 'Invalid attacker id.' + raise ValueError(msg) del self._id_to_attacker[attacker.id] diff --git a/maltoolbox/attackgraph/node.py b/maltoolbox/attackgraph/node.py index a5884172..f6cd6aed 100644 --- a/maltoolbox/attackgraph/node.py +++ b/maltoolbox/attackgraph/node.py @@ -1,41 +1,43 @@ -""" -MAL-Toolbox Attack Graph Node Dataclass -""" +"""MAL-Toolbox Attack Graph Node Dataclass.""" from __future__ import annotations + import copy -from dataclasses import field, dataclass +from dataclasses import dataclass, field from functools import cached_property -from typing import Any, Optional +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from maltoolbox.language import LanguageGraphAttackStep + + from . import Attacker -from . import Attacker -from ..language import LanguageGraphAttackStep @dataclass class AttackGraphNode: - """Node part of AttackGraph""" + """Node part of AttackGraph.""" + type: str lang_graph_attack_step: LanguageGraphAttackStep name: str - ttc: Optional[dict] = None - id: Optional[int] = None - asset: Optional[Any] = None + ttc: dict | None = None + id: int | None = None + asset: Any | None = None children: list[AttackGraphNode] = field(default_factory=list) parents: list[AttackGraphNode] = field(default_factory=list) - defense_status: Optional[float] = None - existence_status: Optional[bool] = None + defense_status: float | None = None + existence_status: bool | None = None is_viable: bool = True is_necessary: bool = True compromised_by: list[Attacker] = field(default_factory=list) tags: set[str] = field(default_factory=set) - attributes: Optional[dict] = None + attributes: dict | None = None # Optional extra metadata for AttackGraphNode extras: dict = field(default_factory=dict) - def to_dict(self) -> dict: - """Convert node to dictionary""" + """Convert node to dictionary.""" node_dict: dict = { 'id': self.id, 'type': self.type, @@ -44,8 +46,7 @@ def to_dict(self) -> dict: 'ttc': self.ttc, 'children': {}, 'parents': {}, - 'compromised_by': [attacker.name for attacker in \ - self.compromised_by] + 'compromised_by': [attacker.name for attacker in self.compromised_by], } for child in self.children: @@ -69,13 +70,11 @@ def to_dict(self) -> dict: return node_dict - def __repr__(self) -> str: return str(self.to_dict()) - def __deepcopy__(self, memo) -> AttackGraphNode: - """Deep copy an attackgraph node + """Deep copy an attackgraph node. The deepcopy will copy over node specific information, such as type, name, etc., but it will not copy attack graph relations such as @@ -83,7 +82,6 @@ def __deepcopy__(self, memo) -> AttackGraphNode: should be recreated when deepcopying the attack graph itself. """ - # Check if the object is already in the memo dictionary if id(self) in memo: return memo[id(self)] @@ -104,7 +102,7 @@ def __deepcopy__(self, memo) -> AttackGraphNode: [], set(), {}, - {} + {}, ) copied_node.tags = copy.deepcopy(self.tags, memo) @@ -116,74 +114,67 @@ def __deepcopy__(self, memo) -> AttackGraphNode: return copied_node - def is_compromised(self) -> bool: - """ - Return True if any attackers have compromised this node. + """Return True if any attackers have compromised this node. False, otherwise. """ return len(self.compromised_by) > 0 - def is_compromised_by(self, attacker: Attacker) -> bool: - """ - Return True if the attacker given as an argument has compromised this + """Return True if the attacker given as an argument has compromised this node. False, otherwise. Arguments: attacker - the attacker we are interested in + """ return attacker in self.compromised_by - def compromise(self, attacker: Attacker) -> None: - """ - Have the attacker given as a parameter compromise this node. + """Have the attacker given as a parameter compromise this node. Arguments: attacker - the attacker that will compromise the node + """ attacker.compromise(self) - def undo_compromise(self, attacker: Attacker) -> None: - """ - Remove the attacker given as a parameter from the list of attackers + """Remove the attacker given as a parameter from the list of attackers that have compromised this node. Arguments: attacker - the attacker that we wish to remove from the compromised list. + """ attacker.undo_compromise(self) - def is_enabled_defense(self) -> bool: - """ - Return True if this node is a defense node and it is enabled and not + """Return True if this node is a defense node and it is enabled and not suppressed via tags. False, otherwise. """ - return self.type == 'defense' and \ - 'suppress' not in self.tags and \ - self.defense_status == 1.0 - + return ( + self.type == 'defense' + and 'suppress' not in self.tags + and self.defense_status == 1.0 + ) def is_available_defense(self) -> bool: - """ - Return True if this node is a defense node and it is not fully enabled + """Return True if this node is a defense node and it is not fully enabled and not suppressed via tags. False otherwise. """ - return self.type == 'defense' and \ - 'suppress' not in self.tags and \ - self.defense_status != 1.0 - + return ( + self.type == 'defense' + and 'suppress' not in self.tags + and self.defense_status != 1.0 + ) @property def full_name(self) -> str: - """ - Return the full name of the attack step. This is a combination of the + """Return the full name of the attack step. This is a combination of the asset name to which the attack step belongs and attack step name itself. """ @@ -193,7 +184,6 @@ def full_name(self) -> str: full_name = str(self.id) + ':' + self.name return full_name - @cached_property def info(self) -> dict[str, str]: return self.lang_graph_attack_step.info diff --git a/maltoolbox/attackgraph/query.py b/maltoolbox/attackgraph/query.py index 1f9ce11f..5a5ba7ef 100644 --- a/maltoolbox/attackgraph/query.py +++ b/maltoolbox/attackgraph/query.py @@ -1,40 +1,36 @@ -""" -MAL-Toolbox Attack Graph Query Submodule +"""MAL-Toolbox Attack Graph Query Submodule. This submodule contains functions that analyze the information present in the attack graph, but do not alter the structure or nodes in any way. """ + from __future__ import annotations + import logging from typing import TYPE_CHECKING -from .attackgraph import AttackGraph, Attacker - if TYPE_CHECKING: - from .attackgraph import AttackGraphNode + from .attackgraph import Attacker, AttackGraph, AttackGraphNode logger = logging.getLogger(__name__) -def is_node_traversable_by_attacker( - node: AttackGraphNode, attacker: Attacker - ) -> bool: - """ - Return True or False depending if the node specified is traversable + +def is_node_traversable_by_attacker(node: AttackGraphNode, attacker: Attacker) -> bool: + """Return True or False depending if the node specified is traversable for the attacker given. Arguments: node - the node we wish to evalute attacker - the attacker whose traversability we are interested in - """ + """ logger.debug( - 'Evaluate if "%s"(%d), of type "%s", is traversable by Attacker ' - '"%s"(%d)', + 'Evaluate if "%s"(%d), of type "%s", is traversable by Attacker "%s"(%d)', node.full_name, node.id, node.type, attacker.name, - attacker.id + attacker.id, ) if not node.is_viable: logger.debug( @@ -44,20 +40,18 @@ def is_node_traversable_by_attacker( ) return False - match(node.type): + match node.type: case 'or': logger.debug( - '"%s"(%d) is traversable because it is viable and ' - 'of type "or".', + '"%s"(%d) is traversable because it is viable and of type "or".', node.full_name, - node.id + node.id, ) return True case 'and': for parent in node.parents: - if parent.is_necessary and \ - not parent.is_compromised_by(attacker): + if parent.is_necessary and not parent.is_compromised_by(attacker): # If the parent is not present in the attacks steps # already reached and is necessary. logger.debug( @@ -67,7 +61,7 @@ def is_node_traversable_by_attacker( node.full_name, node.id, parent.full_name, - parent.id + parent.id, ) return False logger.debug( @@ -75,7 +69,7 @@ def is_node_traversable_by_attacker( 'of type "and", and all of its necessary parents have ' 'already been compromised.', node.full_name, - node.id + node.id, ) return True @@ -87,7 +81,7 @@ def is_node_traversable_by_attacker( 'checked for traversability.', node.full_name, node.id, - node.type + node.type, ) return False @@ -96,50 +90,48 @@ def is_node_traversable_by_attacker( 'Node "%s"(%d) has an unknown type "%s".', node.full_name, node.id, - node.type + node.type, ) return False -def get_attack_surface( - attacker: Attacker - ) -> list[AttackGraphNode]: - """ - Get the current attack surface of an attacker. This includes all of the + +def get_attack_surface(attacker: Attacker) -> list[AttackGraphNode]: + """Get the current attack surface of an attacker. This includes all of the viable children nodes of already reached attack steps that are of 'or' type and the 'and' type children nodes which have all of their necessary parents in the attack steps reached. Arguments: attacker - the Attacker whose attack surface is sought + """ logger.debug( - 'Get the attack surface for Attacker "%s"(%d).', - attacker.name, - attacker.id + 'Get the attack surface for Attacker "%s"(%d).', attacker.name, attacker.id ) attack_surface = [] for attack_step in attacker.reached_attack_steps: logger.debug( - 'Determine attack surface stemming from ' - '"%s"(%d) for Attacker "%s"(%d).', + 'Determine attack surface stemming from "%s"(%d) for Attacker "%s"(%d).', attack_step.full_name, attack_step.id, attacker.name, - attacker.id + attacker.id, ) for child in attack_step.children: - if is_node_traversable_by_attacker(child, attacker) and \ - child not in attack_surface: + if ( + is_node_traversable_by_attacker(child, attacker) + and child not in attack_surface + ): attack_surface.append(child) return attack_surface + def update_attack_surface_add_nodes( - attacker: Attacker, - current_attack_surface: list[AttackGraphNode], - nodes: list[AttackGraphNode] - ) -> list[AttackGraphNode]: - """ - Update the attack surface of an attacker with the new attack step nodes + attacker: Attacker, + current_attack_surface: list[AttackGraphNode], + nodes: list[AttackGraphNode], +) -> list[AttackGraphNode]: + """Update the attack surface of an attacker with the new attack step nodes provided to see if any of their children can be added. Arguments: @@ -149,52 +141,53 @@ def update_attack_surface_add_nodes( nodes - the newly compromised attack step nodes that we wish to see if any of their children should be added to the attack surface + """ - logger.debug('Update the attack surface for Attacker "%s"(%d).', - attacker.name, - attacker.id) + logger.debug( + 'Update the attack surface for Attacker "%s"(%d).', attacker.name, attacker.id + ) attack_surface = current_attack_surface for attack_step in nodes: logger.debug( - 'Determine attack surface stemming from "%s"(%d) ' - 'for Attacker "%s"(%d).', + 'Determine attack surface stemming from "%s"(%d) for Attacker "%s"(%d).', attack_step.full_name, attack_step.id, attacker.name, - attacker.id + attacker.id, ) for child in attack_step.children: is_traversable = is_node_traversable_by_attacker(child, attacker) if is_traversable and child not in attack_surface: logger.debug( - 'Add node "%s"(%d) to the attack surface of ' - 'Attacker "%s"(%d).', + 'Add node "%s"(%d) to the attack surface of Attacker "%s"(%d).', child.full_name, child.id, attacker.name, - attacker.id + attacker.id, ) attack_surface.append(child) return attack_surface + def get_defense_surface(graph: AttackGraph) -> list[AttackGraphNode]: - """ - Get the defense surface. All non-suppressed defense steps that are not + """Get the defense surface. All non-suppressed defense steps that are not already fully enabled. Arguments: graph - the attack graph + """ logger.debug('Get the defense surface.') return [node for node in graph.nodes if node.is_available_defense()] + def get_enabled_defenses(graph: AttackGraph) -> list[AttackGraphNode]: - """ - Get the defenses already enabled. All non-suppressed defense steps that + """Get the defenses already enabled. All non-suppressed defense steps that are already fully enabled. Arguments: graph - the attack graph + """ logger.debug('Get the enabled defenses.') return [node for node in graph.nodes if node.is_enabled_defense()] diff --git a/maltoolbox/exceptions.py b/maltoolbox/exceptions.py index 9012431f..ca61fa57 100644 --- a/maltoolbox/exceptions.py +++ b/maltoolbox/exceptions.py @@ -1,14 +1,14 @@ class MalToolboxException(Exception): """Base exception for all other maltoolbox exceptions to inherit from.""" - pass + class LanguageGraphException(MalToolboxException): """Base exception for all language-graph related exceptions.""" - pass + class LanguageGraphSuperAssetNotFoundError(LanguageGraphException): """Asset's super asset not found in language graph during attack graph construction.""" - pass + class LanguageGraphAssociationError(LanguageGraphException): """Error in building an association. @@ -16,30 +16,27 @@ class LanguageGraphAssociationError(LanguageGraphException): For example, right or left-hand side asset of association missing in language graph. """ - pass + class LanguageGraphStepExpressionError(LanguageGraphException): """A target asset cannot be linked with for a step expression.""" - pass + class AttackGraphException(MalToolboxException): """Base exception for all attack-graph related exceptions.""" - pass + class AttackGraphStepExpressionError(AttackGraphException): """A target attack step cannot be linked with for a step expression.""" - pass class ModelException(MalToolboxException): - """Base Exception for all Model related exceptions""" - pass + """Base Exception for all Model related exceptions.""" class ModelAssociationException(ModelException): - """Exception related to associations in Model""" - pass + """Exception related to associations in Model.""" + class DuplicateModelAssociationError(ModelException): - """Associations should be unique as part of Model""" - pass + """Associations should be unique as part of Model.""" diff --git a/maltoolbox/file_utils.py b/maltoolbox/file_utils.py index 35f3718e..ebcaa674 100644 --- a/maltoolbox/file_utils.py +++ b/maltoolbox/file_utils.py @@ -1,17 +1,19 @@ -"""Utily functions for file handling""" +"""Utily functions for file handling.""" import json + import yaml from python_jsonschema_objects.literals import LiteralValue + def save_dict_to_json_file(filename: str, serialized_object: dict) -> None: """Save serialized object to a json file. Arguments: filename - the name of the output file data - dict to output as json - """ + """ with open(filename, 'w', encoding='utf-8') as f: json.dump(serialized_object, f, indent=4) @@ -22,17 +24,18 @@ def save_dict_to_yaml_file(filename: str, serialized_object: dict) -> None: Arguments: filename - the name of the output file data - dict to output as yaml + """ class NoAliasSafeDumper(yaml.SafeDumper): - def ignore_aliases(self, data): + def ignore_aliases(self, data) -> bool: return True # Handle Literal values from jsonschema_objects yaml.add_multi_representer( LiteralValue, lambda dumper, data: dumper.represent_data(data._value), - NoAliasSafeDumper + NoAliasSafeDumper, ) with open(filename, 'w', encoding='utf-8') as f: @@ -40,17 +43,15 @@ def ignore_aliases(self, data): def load_dict_from_yaml_file(filename: str) -> dict: - """Open json file and read as dict""" - with open(filename, 'r', encoding='utf-8') as file: - object_dict = yaml.safe_load(file) - return object_dict + """Open json file and read as dict.""" + with open(filename, encoding='utf-8') as file: + return yaml.safe_load(file) def load_dict_from_json_file(filename: str) -> dict: - """Open yaml file and read as dict""" - with open(filename, 'r', encoding='utf-8') as file: - object_dict = json.loads(file.read()) - return object_dict + """Open yaml file and read as dict.""" + with open(filename, encoding='utf-8') as file: + return json.loads(file.read()) def save_dict_to_file(filename: str, dictionary: dict) -> None: @@ -60,11 +61,12 @@ def save_dict_to_file(filename: str, dictionary: dict) -> None: Arguments: filename - the name of the output file dictionary - the dict to save to the file - """ + """ if filename.endswith(('.yml', '.yaml')): save_dict_to_yaml_file(filename, dictionary) elif filename.endswith('.json'): save_dict_to_json_file(filename, dictionary) else: - raise ValueError('Unknown file extension, expected json/yml/yaml') + msg = 'Unknown file extension, expected json/yml/yaml' + raise ValueError(msg) diff --git a/maltoolbox/ingestors/neo4j.py b/maltoolbox/ingestors/neo4j.py index 4e6a954f..d4d0d010 100644 --- a/maltoolbox/ingestors/neo4j.py +++ b/maltoolbox/ingestors/neo4j.py @@ -1,26 +1,20 @@ -""" -MAL-Toolbox Neo4j Ingestor Module -""" +"""MAL-Toolbox Neo4j Ingestor Module.""" # mypy: ignore-errors import logging from py2neo import Graph, Node, Relationship, Subgraph -from ..model import AttackerAttachment, Model -from ..language import LanguageGraph, LanguageClassesFactory +from maltoolbox.language import LanguageClassesFactory, LanguageGraph +from maltoolbox.model import AttackerAttachment, Model logger = logging.getLogger(__name__) -def ingest_attack_graph(graph, - uri: str, - username: str, - password: str, - dbname: str, - delete: bool = False - ) -> None: - """ - Ingest an attack graph into a neo4j database + +def ingest_attack_graph( + graph, uri: str, username: str, password: str, dbname: str, delete: bool = False +) -> None: + """Ingest an attack graph into a neo4j database. Arguments: graph - the attackgraph provided by the atkgraph.py module. @@ -30,8 +24,8 @@ def ingest_attack_graph(graph, dbname - the selected database delete - if True, the previous content of the database is deleted before ingesting the new attack graph - """ + """ g = Graph(uri=uri, user=username, password=password, name=dbname) if delete: g.delete_all() @@ -42,20 +36,20 @@ def ingest_attack_graph(graph, node_dict = node.to_dict() nodes[node.id] = Node( node_dict['asset'] if 'asset' in node_dict else node_dict['id'], - name = node_dict['name'], - full_name = node.full_name, - type = node_dict['type'], - ttc = str(node_dict['ttc']), - is_necessary = str(node.is_necessary), - is_viable = str(node.is_viable), - compromised_by = str(node_dict['compromised_by']), - defense_status = node_dict['defense_status'] if 'defense_status' - in node_dict else 'N/A') - + name=node_dict['name'], + full_name=node.full_name, + type=node_dict['type'], + ttc=str(node_dict['ttc']), + is_necessary=str(node.is_necessary), + is_viable=str(node.is_viable), + compromised_by=str(node_dict['compromised_by']), + defense_status=node_dict.get('defense_status', 'N/A'), + ) for node in graph.nodes: - for child in node.children: - rels.append(Relationship(nodes[node.id], nodes[child.id])) + rels.extend( + Relationship(nodes[node.id], nodes[child.id]) for child in node.children + ) subgraph = Subgraph(list(nodes.values()), rels) @@ -64,15 +58,10 @@ def ingest_attack_graph(graph, g.commit(tx) -def ingest_model(model, - uri: str, - username: str, - password: str, - dbname: str, - delete: bool = False - ) -> None: - """ - Ingest an instance model graph into a Neo4J database +def ingest_model( + model, uri: str, username: str, password: str, dbname: str, delete: bool = False +) -> None: + """Ingest an instance model graph into a Neo4J database. Arguments: model - the instance model dictionary as provided by the model.py module @@ -82,6 +71,7 @@ def ingest_model(model, dbname - the selected database delete - if True, the previous content of the database is deleted before ingesting the new attack graph + """ g = Graph(uri=uri, user=username, password=password, name=dbname) if delete: @@ -91,11 +81,12 @@ def ingest_model(model, rels = [] for asset in model.assets: - - nodes[str(asset.id)] = Node(str(asset.type), - name=str(asset.name), - asset_id=str(asset.id), - type=str(asset.type)) + nodes[str(asset.id)] = Node( + str(asset.type), + name=str(asset.name), + asset_id=str(asset.id), + type=str(asset.type), + ) for assoc in model.associations: firstElementName, secondElementName = assoc._properties.keys() @@ -103,13 +94,20 @@ def ingest_model(model, secondElements = getattr(assoc, secondElementName) for first_asset in firstElements: for second_asset in secondElements: - rels.append(Relationship(nodes[str(first_asset.id)], - str(firstElementName), - nodes[str(second_asset.id)])) - rels.append(Relationship(nodes[str(second_asset.id)], - str(secondElementName), - nodes[str(first_asset.id)])) - + rels.extend( + ( + Relationship( + nodes[str(first_asset.id)], + str(firstElementName), + nodes[str(second_asset.id)], + ), + Relationship( + nodes[str(second_asset.id)], + str(secondElementName), + nodes[str(first_asset.id)], + ), + ) + ) subgraph = Subgraph(list(nodes.values()), rels) @@ -119,56 +117,61 @@ def ingest_model(model, def get_model( - uri: str, - username: str, - password: str, - dbname: str, - lang_graph: LanguageGraph, - lang_classes_factory: LanguageClassesFactory - ) -> Model: - """Load a model from Neo4j""" - + uri: str, + username: str, + password: str, + dbname: str, + lang_graph: LanguageGraph, + lang_classes_factory: LanguageClassesFactory, +) -> Model: + """Load a model from Neo4j.""" g = Graph(uri=uri, user=username, password=password, name=dbname) instance_model = Model('Neo4j imported model', lang_classes_factory) # Get all assets - assets_results = g.run('MATCH (a) WHERE a.type IS NOT NULL RETURN DISTINCT a').data() + assets_results = g.run( + 'MATCH (a) WHERE a.type IS NOT NULL RETURN DISTINCT a' + ).data() for asset in assets_results: asset_data = dict(asset['a']) - logger.debug( - 'Loading asset from Neo4j instance:\n%s', str(asset_data) - ) + logger.debug('Loading asset from Neo4j instance:\n%s', str(asset_data)) if asset_data['type'] == 'Attacker': attacker_id = int(asset_data['asset_id']) attacker = AttackerAttachment() attacker.entry_points = [] - instance_model.add_attacker(attacker, attacker_id = attacker_id) + instance_model.add_attacker(attacker, attacker_id=attacker_id) continue if not hasattr(lang_classes_factory.ns, asset_data['type']): msg = 'Failed to find %s asset in language specification!' - logger.error(msg, asset_data["type"]) - raise LookupError(msg % asset_data["type"]) + logger.error(msg, asset_data['type']) + raise LookupError(msg % asset_data['type']) - asset_obj = getattr(lang_classes_factory.ns, - asset_data['type'])(name = asset_data['name']) + asset_obj = getattr(lang_classes_factory.ns, asset_data['type'])( + name=asset_data['name'] + ) asset_id = int(asset_data['asset_id']) - #TODO Process defense values when they are included in Neo4j + # TODO Process defense values when they are included in Neo4j instance_model.add_asset(asset_obj, asset_id) # Get all relationships - assocs_results = g.run('MATCH (a)-[r1]->(b),(a)<-[r2]-(b) WHERE a.type IS NOT NULL RETURN DISTINCT a, r1, r2, b').data() + assocs_results = g.run( + 'MATCH (a)-[r1]->(b),(a)<-[r2]-(b) WHERE a.type IS NOT NULL RETURN DISTINCT a, r1, r2, b' + ).data() for assoc in assocs_results: - left_field = list(assoc['r1'].types())[0] - right_field = list(assoc['r2'].types())[0] + left_field = next(iter(assoc['r1'].types())) + right_field = next(iter(assoc['r2'].types())) left_asset = dict(assoc['a']) right_asset = dict(assoc['b']) logger.debug( 'Load association ("%s", "%s", "%s", "%s") from Neo4j instance.', - left_field, right_field, left_asset["type"], right_asset["type"] + left_field, + right_field, + left_asset['type'], + right_asset['type'], ) left_id = int(left_asset['asset_id']) @@ -195,8 +198,7 @@ def get_model( msg = 'Failed to find asset with id %d in model!' logger.error(msg, target_id) raise LookupError(msg % target_id) - attacker.entry_points.append((target_asset, - [target_prop])) + attacker.entry_points.append((target_asset, [target_prop])) continue left_asset = instance_model.get_asset_by_id(left_id) @@ -211,45 +213,42 @@ def get_model( raise LookupError(msg % right_id) assoc = lang_graph.get_association_by_fields_and_assets( - left_field, - right_field, - left_asset.type, - right_asset.type) + left_field, right_field, left_asset.type, right_asset.type + ) if not assoc: logger.error( 'Failed to find ("%s", "%s", "%s", "%s")' 'association in language specification!', - left_asset.type, right_asset.type, - left_field, right_field + left_asset.type, + right_asset.type, + left_field, + right_field, ) return None logger.debug('Found "%s" association.', assoc.name) assoc_name = lang_classes_factory.get_association_by_signature( - assoc.name, - left_asset.type, - right_asset.type + assoc.name, left_asset.type, right_asset.type ) if not assoc_name: - msg = 'Failed to find \"%s\" association in language specification!' + msg = 'Failed to find "%s" association in language specification!' logger.error(msg, assoc.name) raise LookupError(msg % assoc.name) assoc = getattr(lang_classes_factory.ns, assoc_name)() setattr(assoc, left_field, [left_asset]) setattr(assoc, right_field, [right_asset]) - if not (instance_model.association_exists_between_assets( - assoc_name, - left_asset, - right_asset - ) or instance_model.association_exists_between_assets( - assoc_name, - right_asset, - left_asset - )): + if not ( + instance_model.association_exists_between_assets( + assoc_name, left_asset, right_asset + ) + or instance_model.association_exists_between_assets( + assoc_name, right_asset, left_asset + ) + ): instance_model.add_association(assoc) return instance_model diff --git a/maltoolbox/language/__init__.py b/maltoolbox/language/__init__.py index 942d3363..9c68ad3f 100644 --- a/maltoolbox/language/__init__.py +++ b/maltoolbox/language/__init__.py @@ -1,8 +1,18 @@ -"""Contains tools to process MAL languages""" +"""Contains tools to process MAL languages.""" -from .languagegraph import (LanguageGraph, - ExpressionsChain, - LanguageGraphAsset, - LanguageGraphAttackStep, - disaggregate_attack_step_full_name) -from .classes_factory import LanguageClassesFactory +from .classes_factory import LanguageClassesFactory as LanguageClassesFactory +from .languagegraph import ( + ExpressionsChain as ExpressionsChain, +) +from .languagegraph import ( + LanguageGraph as LanguageGraph, +) +from .languagegraph import ( + LanguageGraphAsset as LanguageGraphAsset, +) +from .languagegraph import ( + LanguageGraphAttackStep as LanguageGraphAttackStep, +) +from .languagegraph import ( + disaggregate_attack_step_full_name as disaggregate_attack_step_full_name, +) diff --git a/maltoolbox/language/classes_factory.py b/maltoolbox/language/classes_factory.py index 9b58b65a..518ded0f 100644 --- a/maltoolbox/language/classes_factory.py +++ b/maltoolbox/language/classes_factory.py @@ -1,8 +1,9 @@ +"""MAL-Toolbox Language Classes Factory Module +Uses python_jsonschema_objects to generate python classes from a MAL language. """ -MAL-Toolbox Language Classes Factory Module -Uses python_jsonschema_objects to generate python classes from a MAL language -""" + from __future__ import annotations + import json import logging from typing import TYPE_CHECKING @@ -10,24 +11,25 @@ import python_jsonschema_objects as pjs if TYPE_CHECKING: - from typing import Literal, Optional, TypeAlias - from maltoolbox.language import LanguageGraph + from typing import Literal, TypeAlias + from python_jsonschema_objects.classbuilder import ProtocolBase + from maltoolbox.language import LanguageGraph + SchemaGeneratedClass: TypeAlias = ProtocolBase logger = logging.getLogger(__name__) + class LanguageClassesFactory: - def __init__(self, lang_graph: LanguageGraph): + def __init__(self, lang_graph: LanguageGraph) -> None: self.lang_graph: LanguageGraph = lang_graph self.json_schema: dict = {} self._create_classes() def _generate_assets(self) -> None: - """ - Generate JSON Schema for asset types in the language specification. - """ + """Generate JSON Schema for asset types in the language specification.""" for asset_name, asset in self.lang_graph.assets.items(): logger.debug('Creating %s asset JSON schema entry.', asset.name) asset_json_entry = { @@ -36,17 +38,18 @@ def _generate_assets(self) -> None: 'properties': {}, } asset_json_entry['properties']['id'] = { - 'type' : 'integer', + 'type': 'integer', + } + asset_json_entry['properties']['type'] = { + 'type': 'string', + 'default': asset_name, } - asset_json_entry['properties']['type'] = \ - { - 'type' : 'string', - 'default': asset_name - } if asset.own_super_asset: asset_json_entry['allOf'] = [ - {'$ref': '#/definitions/LanguageAsset/definitions/Asset_'\ - + asset.own_super_asset.name} + { + '$ref': '#/definitions/LanguageAsset/definitions/Asset_' + + asset.own_super_asset.name + } ] for step_name, step in asset.attack_steps.items(): if step.type == 'defense': @@ -54,119 +57,121 @@ def _generate_assets(self) -> None: default_defense_value = 1.0 else: default_defense_value = 0.0 - asset_json_entry['properties'][step_name] = \ - { - 'type' : 'number', - 'minimum' : 0, - 'maximum' : 1, - 'default': default_defense_value - } - self.json_schema['definitions']['LanguageAsset']['definitions']\ - ['Asset_' + asset_name] = asset_json_entry + asset_json_entry['properties'][step_name] = { + 'type': 'number', + 'minimum': 0, + 'maximum': 1, + 'default': default_defense_value, + } + self.json_schema['definitions']['LanguageAsset']['definitions'][ + 'Asset_' + asset_name + ] = asset_json_entry self.json_schema['definitions']['LanguageAsset']['oneOf'].append( {'$ref': '#/definitions/LanguageAsset/definitions/Asset_' + asset_name} ) def _generate_associations(self) -> None: - """ - Generate JSON Schema for association types in the language specification. - """ + """Generate JSON Schema for association types in the language specification.""" + def create_association_entry(assoc: SchemaGeneratedClass): logger.debug('Creating %s association JSON schema entry.', assoc.name) assoc_json_entry = { 'title': 'Association_' + assoc.full_name, 'type': 'object', - 'properties': {} + 'properties': {}, + } + assoc_json_entry['properties']['type'] = { + 'type': 'string', + 'default': assoc.name, } - assoc_json_entry['properties']['type'] = \ - { - 'type' : 'string', - 'default': assoc.name - } create_association_field(assoc, assoc_json_entry, 'left') create_association_field(assoc, assoc_json_entry, 'right') return assoc_json_entry def create_association_field( - assoc: SchemaGeneratedClass, - assoc_json_entry: dict, - position: Literal['left', 'right'] - ) -> None: - field = getattr(assoc, position + "_field") - assoc_json_entry['properties'][field.fieldname] = \ - { - 'type' : 'array', - 'items' : - { - '$ref': - '#/definitions/LanguageAsset/definitions/Asset_' + - field.asset.name - } - } + assoc: SchemaGeneratedClass, + assoc_json_entry: dict, + position: Literal['left', 'right'], + ) -> None: + field = getattr(assoc, position + '_field') + assoc_json_entry['properties'][field.fieldname] = { + 'type': 'array', + 'items': { + '$ref': '#/definitions/LanguageAsset/definitions/Asset_' + + field.asset.name + }, + } if field.maximum: - assoc_json_entry['properties'][field.fieldname]\ - ['maxItems'] = field.maximum + assoc_json_entry['properties'][field.fieldname]['maxItems'] = ( + field.maximum + ) - for asset_name, asset in self.lang_graph.assets.items(): + for asset in self.lang_graph.assets.values(): for assoc_name, assoc in asset.associations.items(): - if assoc_name not in self.json_schema['definitions']\ - ['LanguageAssociation']['definitions']: + if ( + assoc_name + not in self.json_schema['definitions']['LanguageAssociation'][ + 'definitions' + ] + ): assoc_json_entry = create_association_entry(assoc) - self.json_schema['definitions']['LanguageAssociation']\ - ['definitions']['Association_' + assoc_name] = \ - assoc_json_entry - self.json_schema['definitions']['LanguageAssociation']['oneOf'].\ - append({'$ref': '#/definitions/LanguageAssociation/' + - 'definitions/Association_' + assoc_name}) + self.json_schema['definitions']['LanguageAssociation'][ + 'definitions' + ]['Association_' + assoc_name] = assoc_json_entry + self.json_schema['definitions']['LanguageAssociation'][ + 'oneOf' + ].append( + { + '$ref': '#/definitions/LanguageAssociation/' + + 'definitions/Association_' + + assoc_name + } + ) def _create_classes(self) -> None: - """ - Create classes based on the language specification. - """ - + """Create classes based on the language specification.""" # First, we have to translate the language specification into a JSON # schema. Initialize the overall JSON schema structure. self.json_schema = { '$schema': 'http://json-schema.org/draft-04/schema#', - 'id': f"urn:mal:{__name__.replace('.', ':')}", + 'id': f'urn:mal:{__name__.replace(".", ":")}', 'title': 'LanguageObject', 'type': 'object', - 'oneOf':[ + 'oneOf': [ {'$ref': '#/definitions/LanguageAsset'}, - {'$ref': '#/definitions/LanguageAssociation'} + {'$ref': '#/definitions/LanguageAssociation'}, ], - 'definitions': {}} + 'definitions': {}, + } self.json_schema['definitions']['LanguageAsset'] = { 'title': 'LanguageAsset', 'type': 'object', 'oneOf': [], - 'definitions': {}} + 'definitions': {}, + } self.json_schema['definitions']['LanguageAssociation'] = { 'title': 'LanguageAssociation', 'type': 'object', 'oneOf': [], - 'definitions': {}} + 'definitions': {}, + } self._generate_assets() self._generate_associations() if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug - logger.debug(json.dumps(self.json_schema, indent = 2)) + logger.debug(json.dumps(self.json_schema, indent=2)) # Once we have the JSON schema we create the actual classes. builder = pjs.ObjectBuilder(self.json_schema) self.ns = builder.build_classes(standardize_names=False) def get_association_by_signature( - self, - assoc_name: str, - left_asset: str, - right_asset: str - ) -> Optional[str]: - """ - Get association name based on its signature. This is primarily + self, assoc_name: str, left_asset: str, right_asset: str + ) -> str | None: + """Get association name based on its signature. This is primarily relevant for getting the exact association full name when multiple associations with the same name exist. @@ -177,83 +182,62 @@ def get_association_by_signature( Return: The matching association name if a match is found. None if there is no match. + """ - lang_assocs_entries = self.json_schema['definitions']\ - ['LanguageAssociation']['definitions'] - if not assoc_name in lang_assocs_entries: - raise LookupError( - 'Failed to find "%s" association in the language json ' - 'schema.' % assoc_name + lang_assocs_entries = self.json_schema['definitions']['LanguageAssociation'][ + 'definitions' + ] + if assoc_name not in lang_assocs_entries: + msg = ( + f'Failed to find "{assoc_name}" association in the language json ' + 'schema.' ) + raise LookupError(msg) assoc_entry = lang_assocs_entries[assoc_name] # If the association has a oneOf property it should always have more # than just one alternative, but check just in case - if 'definitions' in assoc_entry and \ - len(assoc_entry['definitions']) > 1: - full_name = '%s_%s_%s' % ( - assoc_name, - left_asset, - right_asset - ) - full_name_flipped = '%s_%s_%s' % ( - assoc_name, - right_asset, - left_asset - ) - if not full_name in assoc_entry['definitions']: - if not full_name_flipped in assoc_entry['definitions']: - raise LookupError( - 'Failed to find "%s" or "%s" association in the ' + if 'definitions' in assoc_entry and len(assoc_entry['definitions']) > 1: + full_name = f'{assoc_name}_{left_asset}_{right_asset}' + full_name_flipped = f'{assoc_name}_{right_asset}_{left_asset}' + if full_name not in assoc_entry['definitions']: + if full_name_flipped not in assoc_entry['definitions']: + msg = ( + f'Failed to find "{full_name}" or "{full_name_flipped}" association in the ' 'language json schema.' - % (full_name, - full_name_flipped) ) - else: - return full_name_flipped - else: - return full_name - else: - return assoc_name + raise LookupError(msg) + return full_name_flipped + return full_name + return assoc_name - def get_asset_class(self, - asset_name: str - ) -> Optional[SchemaGeneratedClass]: + def get_asset_class(self, asset_name: str) -> SchemaGeneratedClass | None: class_name = 'Asset_' + asset_name if hasattr(self.ns, class_name): class_obj = getattr(self.ns, class_name) class_obj.__hash__ = lambda self: hash(self.name) return class_obj - else: - logger.warning('Could not find Asset "%s" in classes factory.' % - asset_name) - return None + logger.warning(f'Could not find Asset "{asset_name}" in classes factory.') + return None - def get_association_class(self, - assoc_name: str - ) -> Optional[SchemaGeneratedClass]: + def get_association_class(self, assoc_name: str) -> SchemaGeneratedClass | None: class_name = 'Association_' + assoc_name if hasattr(self.ns, class_name): return getattr(self.ns, class_name) - else: - logger.warning('Could not find Association "%s" in classes factory.' % - assoc_name) - return None + logger.warning(f'Could not find Association "{assoc_name}" in classes factory.') + return None - def get_association_class_by_fieldnames(self, - assoc_name: str, - fieldname1: str, - fieldname2: str - ) -> Optional[SchemaGeneratedClass]: - class_name = 'Association_%s_%s_%s' % (assoc_name, - fieldname1, fieldname2) - class_name_alt = 'Association_%s_%s_%s' % (assoc_name, - fieldname2, fieldname1) + def get_association_class_by_fieldnames( + self, assoc_name: str, fieldname1: str, fieldname2: str + ) -> SchemaGeneratedClass | None: + class_name = f'Association_{assoc_name}_{fieldname1}_{fieldname2}' + class_name_alt = f'Association_{assoc_name}_{fieldname2}_{fieldname1}' if hasattr(self.ns, class_name): return getattr(self.ns, class_name) - elif hasattr(self.ns, class_name_alt): + if hasattr(self.ns, class_name_alt): return getattr(self.ns, class_name_alt) - else: - logger.warning('Could not find Association "%s" or "%s" in ' - 'classes factory.' % (class_name, class_name_alt)) - return None + logger.warning( + f'Could not find Association "{class_name}" or "{class_name_alt}" in ' + 'classes factory.' + ) + return None diff --git a/maltoolbox/language/compiler/__init__.py b/maltoolbox/language/compiler/__init__.py index 3f1993b0..164009c7 100644 --- a/maltoolbox/language/compiler/__init__.py +++ b/maltoolbox/language/compiler/__init__.py @@ -2,27 +2,27 @@ # mypy: ignore-errors import os -from typing import Optional -from antlr4 import FileStream, CommonTokenStream +from antlr4 import CommonTokenStream, FileStream + from .mal_lexer import malLexer from .mal_parser import malParser from .mal_visitor import malVisitor class MalCompiler: - def __init__(self): + def __init__(self) -> None: self.path = None self.current_file = None - def compile(self, malfile: Optional[str] = None): + def compile(self, malfile: str | None = None): if not self.path: self.path = os.path.dirname(malfile) self.current_file = os.path.basename(malfile) input_stream = FileStream( - os.path.join(self.path, self.current_file), encoding="utf-8" + os.path.join(self.path, self.current_file), encoding='utf-8' ) lexer = malLexer(input_stream) stream = CommonTokenStream(lexer) diff --git a/maltoolbox/language/compiler/mal_lexer.py b/maltoolbox/language/compiler/mal_lexer.py index 5808c9d3..ae55dde8 100644 --- a/maltoolbox/language/compiler/mal_lexer.py +++ b/maltoolbox/language/compiler/mal_lexer.py @@ -1,7 +1,9 @@ # mypy: ignore-errors # Generated from mal.g4 by ANTLR 4.13.1 -from antlr4 import * import sys + +from antlr4 import * + if sys.version_info[1] > 5: from typing import TextIO else: @@ -10,115 +12,2581 @@ def serializedATN(): return [ - 4,0,48,296,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, - 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, - 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, - 19,2,20,7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2, - 26,7,26,2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7, - 32,2,33,7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2, - 39,7,39,2,40,7,40,2,41,7,41,2,42,7,42,2,43,7,43,2,44,7,44,2,45,7, - 45,2,46,7,46,2,47,7,47,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,1,1, - 1,1,1,1,1,1,1,1,1,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1, - 2,1,2,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,4,1,4,1,4,1,4,1,4,1,4,1, - 4,1,4,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,6,1,6,1,6,1,6,1,6,1, - 7,1,7,1,7,1,7,1,8,1,8,5,8,162,8,8,10,8,12,8,165,9,8,1,8,1,8,1,9, - 4,9,170,8,9,11,9,12,9,171,1,10,5,10,175,8,10,10,10,12,10,178,9,10, - 1,10,1,10,4,10,182,8,10,11,10,12,10,183,1,11,1,11,1,12,1,12,1,13, - 1,13,1,14,1,14,1,15,4,15,195,8,15,11,15,12,15,196,1,16,1,16,1,17, - 1,17,1,18,1,18,1,19,1,19,1,20,1,20,1,21,1,21,1,22,1,22,1,22,1,22, - 1,23,1,23,1,23,1,23,1,24,1,24,1,25,1,25,1,26,1,26,1,27,1,27,1,28, - 1,28,1,29,1,29,1,30,1,30,1,30,1,31,1,31,1,31,1,32,1,32,1,32,1,33, - 1,33,1,34,1,34,1,35,1,35,1,36,1,36,1,36,1,37,1,37,1,38,1,38,1,38, - 1,39,1,39,1,39,1,40,1,40,1,40,1,41,1,41,1,42,1,42,1,43,1,43,1,44, - 1,44,1,45,1,45,1,45,1,45,5,45,272,8,45,10,45,12,45,275,9,45,1,45, - 1,45,1,46,1,46,1,46,1,46,5,46,283,8,46,10,46,12,46,286,9,46,1,46, - 1,46,1,46,1,46,1,46,1,47,1,47,1,47,1,47,2,163,284,0,48,1,1,3,2,5, - 3,7,4,9,5,11,6,13,7,15,8,17,9,19,10,21,11,23,12,25,13,27,14,29,15, - 31,16,33,17,35,18,37,19,39,20,41,21,43,22,45,23,47,24,49,25,51,26, - 53,27,55,28,57,29,59,30,61,31,63,32,65,33,67,34,69,35,71,36,73,37, - 75,38,77,39,79,40,81,41,83,42,85,43,87,44,89,45,91,46,93,47,95,48, - 1,0,4,1,0,48,57,4,0,48,57,65,90,95,95,97,122,2,0,10,10,13,13,3,0, - 9,10,13,13,32,32,302,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0,7,1,0, - 0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17,1,0, - 0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0,27,1,0, - 0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0,37,1,0, - 0,0,0,39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0,47,1,0, - 0,0,0,49,1,0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0,57,1,0, - 0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0,67,1,0, - 0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0,77,1,0, - 0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0,87,1,0, - 0,0,0,89,1,0,0,0,0,91,1,0,0,0,0,93,1,0,0,0,0,95,1,0,0,0,1,97,1,0, - 0,0,3,106,1,0,0,0,5,112,1,0,0,0,7,125,1,0,0,0,9,133,1,0,0,0,11,141, - 1,0,0,0,13,150,1,0,0,0,15,155,1,0,0,0,17,159,1,0,0,0,19,169,1,0, - 0,0,21,176,1,0,0,0,23,185,1,0,0,0,25,187,1,0,0,0,27,189,1,0,0,0, - 29,191,1,0,0,0,31,194,1,0,0,0,33,198,1,0,0,0,35,200,1,0,0,0,37,202, - 1,0,0,0,39,204,1,0,0,0,41,206,1,0,0,0,43,208,1,0,0,0,45,210,1,0, - 0,0,47,214,1,0,0,0,49,218,1,0,0,0,51,220,1,0,0,0,53,222,1,0,0,0, - 55,224,1,0,0,0,57,226,1,0,0,0,59,228,1,0,0,0,61,230,1,0,0,0,63,233, - 1,0,0,0,65,236,1,0,0,0,67,239,1,0,0,0,69,241,1,0,0,0,71,243,1,0, - 0,0,73,245,1,0,0,0,75,248,1,0,0,0,77,250,1,0,0,0,79,253,1,0,0,0, - 81,256,1,0,0,0,83,259,1,0,0,0,85,261,1,0,0,0,87,263,1,0,0,0,89,265, - 1,0,0,0,91,267,1,0,0,0,93,278,1,0,0,0,95,292,1,0,0,0,97,98,5,97, - 0,0,98,99,5,98,0,0,99,100,5,115,0,0,100,101,5,116,0,0,101,102,5, - 114,0,0,102,103,5,97,0,0,103,104,5,99,0,0,104,105,5,116,0,0,105, - 2,1,0,0,0,106,107,5,97,0,0,107,108,5,115,0,0,108,109,5,115,0,0,109, - 110,5,101,0,0,110,111,5,116,0,0,111,4,1,0,0,0,112,113,5,97,0,0,113, - 114,5,115,0,0,114,115,5,115,0,0,115,116,5,111,0,0,116,117,5,99,0, - 0,117,118,5,105,0,0,118,119,5,97,0,0,119,120,5,116,0,0,120,121,5, - 105,0,0,121,122,5,111,0,0,122,123,5,110,0,0,123,124,5,115,0,0,124, - 6,1,0,0,0,125,126,5,101,0,0,126,127,5,120,0,0,127,128,5,116,0,0, - 128,129,5,101,0,0,129,130,5,110,0,0,130,131,5,100,0,0,131,132,5, - 115,0,0,132,8,1,0,0,0,133,134,5,105,0,0,134,135,5,110,0,0,135,136, - 5,99,0,0,136,137,5,108,0,0,137,138,5,117,0,0,138,139,5,100,0,0,139, - 140,5,101,0,0,140,10,1,0,0,0,141,142,5,99,0,0,142,143,5,97,0,0,143, - 144,5,116,0,0,144,145,5,101,0,0,145,146,5,103,0,0,146,147,5,111, - 0,0,147,148,5,114,0,0,148,149,5,121,0,0,149,12,1,0,0,0,150,151,5, - 105,0,0,151,152,5,110,0,0,152,153,5,102,0,0,153,154,5,111,0,0,154, - 14,1,0,0,0,155,156,5,108,0,0,156,157,5,101,0,0,157,158,5,116,0,0, - 158,16,1,0,0,0,159,163,5,34,0,0,160,162,9,0,0,0,161,160,1,0,0,0, - 162,165,1,0,0,0,163,164,1,0,0,0,163,161,1,0,0,0,164,166,1,0,0,0, - 165,163,1,0,0,0,166,167,5,34,0,0,167,18,1,0,0,0,168,170,7,0,0,0, - 169,168,1,0,0,0,170,171,1,0,0,0,171,169,1,0,0,0,171,172,1,0,0,0, - 172,20,1,0,0,0,173,175,7,0,0,0,174,173,1,0,0,0,175,178,1,0,0,0,176, - 174,1,0,0,0,176,177,1,0,0,0,177,179,1,0,0,0,178,176,1,0,0,0,179, - 181,3,67,33,0,180,182,7,0,0,0,181,180,1,0,0,0,182,183,1,0,0,0,183, - 181,1,0,0,0,183,184,1,0,0,0,184,22,1,0,0,0,185,186,5,69,0,0,186, - 24,1,0,0,0,187,188,5,67,0,0,188,26,1,0,0,0,189,190,5,73,0,0,190, - 28,1,0,0,0,191,192,5,65,0,0,192,30,1,0,0,0,193,195,7,1,0,0,194,193, - 1,0,0,0,195,196,1,0,0,0,196,194,1,0,0,0,196,197,1,0,0,0,197,32,1, - 0,0,0,198,199,5,40,0,0,199,34,1,0,0,0,200,201,5,41,0,0,201,36,1, - 0,0,0,202,203,5,123,0,0,203,38,1,0,0,0,204,205,5,125,0,0,205,40, - 1,0,0,0,206,207,5,35,0,0,207,42,1,0,0,0,208,209,5,58,0,0,209,44, - 1,0,0,0,210,211,5,60,0,0,211,212,5,45,0,0,212,213,5,45,0,0,213,46, - 1,0,0,0,214,215,5,45,0,0,215,216,5,45,0,0,216,217,5,62,0,0,217,48, - 1,0,0,0,218,219,5,91,0,0,219,50,1,0,0,0,220,221,5,93,0,0,221,52, - 1,0,0,0,222,223,5,42,0,0,223,54,1,0,0,0,224,225,5,49,0,0,225,56, - 1,0,0,0,226,227,5,61,0,0,227,58,1,0,0,0,228,229,5,45,0,0,229,60, - 1,0,0,0,230,231,5,47,0,0,231,232,5,92,0,0,232,62,1,0,0,0,233,234, - 5,92,0,0,234,235,5,47,0,0,235,64,1,0,0,0,236,237,5,46,0,0,237,238, - 5,46,0,0,238,66,1,0,0,0,239,240,5,46,0,0,240,68,1,0,0,0,241,242, - 5,38,0,0,242,70,1,0,0,0,243,244,5,124,0,0,244,72,1,0,0,0,245,246, - 5,33,0,0,246,247,5,69,0,0,247,74,1,0,0,0,248,249,5,64,0,0,249,76, - 1,0,0,0,250,251,5,60,0,0,251,252,5,45,0,0,252,78,1,0,0,0,253,254, - 5,43,0,0,254,255,5,62,0,0,255,80,1,0,0,0,256,257,5,45,0,0,257,258, - 5,62,0,0,258,82,1,0,0,0,259,260,5,44,0,0,260,84,1,0,0,0,261,262, - 5,43,0,0,262,86,1,0,0,0,263,264,5,47,0,0,264,88,1,0,0,0,265,266, - 5,94,0,0,266,90,1,0,0,0,267,268,5,47,0,0,268,269,5,47,0,0,269,273, - 1,0,0,0,270,272,8,2,0,0,271,270,1,0,0,0,272,275,1,0,0,0,273,271, - 1,0,0,0,273,274,1,0,0,0,274,276,1,0,0,0,275,273,1,0,0,0,276,277, - 6,45,0,0,277,92,1,0,0,0,278,279,5,47,0,0,279,280,5,42,0,0,280,284, - 1,0,0,0,281,283,9,0,0,0,282,281,1,0,0,0,283,286,1,0,0,0,284,285, - 1,0,0,0,284,282,1,0,0,0,285,287,1,0,0,0,286,284,1,0,0,0,287,288, - 5,42,0,0,288,289,5,47,0,0,289,290,1,0,0,0,290,291,6,46,0,0,291,94, - 1,0,0,0,292,293,7,3,0,0,293,294,1,0,0,0,294,295,6,47,0,0,295,96, - 1,0,0,0,8,0,163,171,176,183,196,273,284,1,6,0,0 + 4, + 0, + 48, + 296, + 6, + -1, + 2, + 0, + 7, + 0, + 2, + 1, + 7, + 1, + 2, + 2, + 7, + 2, + 2, + 3, + 7, + 3, + 2, + 4, + 7, + 4, + 2, + 5, + 7, + 5, + 2, + 6, + 7, + 6, + 2, + 7, + 7, + 7, + 2, + 8, + 7, + 8, + 2, + 9, + 7, + 9, + 2, + 10, + 7, + 10, + 2, + 11, + 7, + 11, + 2, + 12, + 7, + 12, + 2, + 13, + 7, + 13, + 2, + 14, + 7, + 14, + 2, + 15, + 7, + 15, + 2, + 16, + 7, + 16, + 2, + 17, + 7, + 17, + 2, + 18, + 7, + 18, + 2, + 19, + 7, + 19, + 2, + 20, + 7, + 20, + 2, + 21, + 7, + 21, + 2, + 22, + 7, + 22, + 2, + 23, + 7, + 23, + 2, + 24, + 7, + 24, + 2, + 25, + 7, + 25, + 2, + 26, + 7, + 26, + 2, + 27, + 7, + 27, + 2, + 28, + 7, + 28, + 2, + 29, + 7, + 29, + 2, + 30, + 7, + 30, + 2, + 31, + 7, + 31, + 2, + 32, + 7, + 32, + 2, + 33, + 7, + 33, + 2, + 34, + 7, + 34, + 2, + 35, + 7, + 35, + 2, + 36, + 7, + 36, + 2, + 37, + 7, + 37, + 2, + 38, + 7, + 38, + 2, + 39, + 7, + 39, + 2, + 40, + 7, + 40, + 2, + 41, + 7, + 41, + 2, + 42, + 7, + 42, + 2, + 43, + 7, + 43, + 2, + 44, + 7, + 44, + 2, + 45, + 7, + 45, + 2, + 46, + 7, + 46, + 2, + 47, + 7, + 47, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 4, + 1, + 4, + 1, + 4, + 1, + 4, + 1, + 4, + 1, + 4, + 1, + 4, + 1, + 4, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 6, + 1, + 6, + 1, + 6, + 1, + 6, + 1, + 6, + 1, + 7, + 1, + 7, + 1, + 7, + 1, + 7, + 1, + 8, + 1, + 8, + 5, + 8, + 162, + 8, + 8, + 10, + 8, + 12, + 8, + 165, + 9, + 8, + 1, + 8, + 1, + 8, + 1, + 9, + 4, + 9, + 170, + 8, + 9, + 11, + 9, + 12, + 9, + 171, + 1, + 10, + 5, + 10, + 175, + 8, + 10, + 10, + 10, + 12, + 10, + 178, + 9, + 10, + 1, + 10, + 1, + 10, + 4, + 10, + 182, + 8, + 10, + 11, + 10, + 12, + 10, + 183, + 1, + 11, + 1, + 11, + 1, + 12, + 1, + 12, + 1, + 13, + 1, + 13, + 1, + 14, + 1, + 14, + 1, + 15, + 4, + 15, + 195, + 8, + 15, + 11, + 15, + 12, + 15, + 196, + 1, + 16, + 1, + 16, + 1, + 17, + 1, + 17, + 1, + 18, + 1, + 18, + 1, + 19, + 1, + 19, + 1, + 20, + 1, + 20, + 1, + 21, + 1, + 21, + 1, + 22, + 1, + 22, + 1, + 22, + 1, + 22, + 1, + 23, + 1, + 23, + 1, + 23, + 1, + 23, + 1, + 24, + 1, + 24, + 1, + 25, + 1, + 25, + 1, + 26, + 1, + 26, + 1, + 27, + 1, + 27, + 1, + 28, + 1, + 28, + 1, + 29, + 1, + 29, + 1, + 30, + 1, + 30, + 1, + 30, + 1, + 31, + 1, + 31, + 1, + 31, + 1, + 32, + 1, + 32, + 1, + 32, + 1, + 33, + 1, + 33, + 1, + 34, + 1, + 34, + 1, + 35, + 1, + 35, + 1, + 36, + 1, + 36, + 1, + 36, + 1, + 37, + 1, + 37, + 1, + 38, + 1, + 38, + 1, + 38, + 1, + 39, + 1, + 39, + 1, + 39, + 1, + 40, + 1, + 40, + 1, + 40, + 1, + 41, + 1, + 41, + 1, + 42, + 1, + 42, + 1, + 43, + 1, + 43, + 1, + 44, + 1, + 44, + 1, + 45, + 1, + 45, + 1, + 45, + 1, + 45, + 5, + 45, + 272, + 8, + 45, + 10, + 45, + 12, + 45, + 275, + 9, + 45, + 1, + 45, + 1, + 45, + 1, + 46, + 1, + 46, + 1, + 46, + 1, + 46, + 5, + 46, + 283, + 8, + 46, + 10, + 46, + 12, + 46, + 286, + 9, + 46, + 1, + 46, + 1, + 46, + 1, + 46, + 1, + 46, + 1, + 46, + 1, + 47, + 1, + 47, + 1, + 47, + 1, + 47, + 2, + 163, + 284, + 0, + 48, + 1, + 1, + 3, + 2, + 5, + 3, + 7, + 4, + 9, + 5, + 11, + 6, + 13, + 7, + 15, + 8, + 17, + 9, + 19, + 10, + 21, + 11, + 23, + 12, + 25, + 13, + 27, + 14, + 29, + 15, + 31, + 16, + 33, + 17, + 35, + 18, + 37, + 19, + 39, + 20, + 41, + 21, + 43, + 22, + 45, + 23, + 47, + 24, + 49, + 25, + 51, + 26, + 53, + 27, + 55, + 28, + 57, + 29, + 59, + 30, + 61, + 31, + 63, + 32, + 65, + 33, + 67, + 34, + 69, + 35, + 71, + 36, + 73, + 37, + 75, + 38, + 77, + 39, + 79, + 40, + 81, + 41, + 83, + 42, + 85, + 43, + 87, + 44, + 89, + 45, + 91, + 46, + 93, + 47, + 95, + 48, + 1, + 0, + 4, + 1, + 0, + 48, + 57, + 4, + 0, + 48, + 57, + 65, + 90, + 95, + 95, + 97, + 122, + 2, + 0, + 10, + 10, + 13, + 13, + 3, + 0, + 9, + 10, + 13, + 13, + 32, + 32, + 302, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 3, + 1, + 0, + 0, + 0, + 0, + 5, + 1, + 0, + 0, + 0, + 0, + 7, + 1, + 0, + 0, + 0, + 0, + 9, + 1, + 0, + 0, + 0, + 0, + 11, + 1, + 0, + 0, + 0, + 0, + 13, + 1, + 0, + 0, + 0, + 0, + 15, + 1, + 0, + 0, + 0, + 0, + 17, + 1, + 0, + 0, + 0, + 0, + 19, + 1, + 0, + 0, + 0, + 0, + 21, + 1, + 0, + 0, + 0, + 0, + 23, + 1, + 0, + 0, + 0, + 0, + 25, + 1, + 0, + 0, + 0, + 0, + 27, + 1, + 0, + 0, + 0, + 0, + 29, + 1, + 0, + 0, + 0, + 0, + 31, + 1, + 0, + 0, + 0, + 0, + 33, + 1, + 0, + 0, + 0, + 0, + 35, + 1, + 0, + 0, + 0, + 0, + 37, + 1, + 0, + 0, + 0, + 0, + 39, + 1, + 0, + 0, + 0, + 0, + 41, + 1, + 0, + 0, + 0, + 0, + 43, + 1, + 0, + 0, + 0, + 0, + 45, + 1, + 0, + 0, + 0, + 0, + 47, + 1, + 0, + 0, + 0, + 0, + 49, + 1, + 0, + 0, + 0, + 0, + 51, + 1, + 0, + 0, + 0, + 0, + 53, + 1, + 0, + 0, + 0, + 0, + 55, + 1, + 0, + 0, + 0, + 0, + 57, + 1, + 0, + 0, + 0, + 0, + 59, + 1, + 0, + 0, + 0, + 0, + 61, + 1, + 0, + 0, + 0, + 0, + 63, + 1, + 0, + 0, + 0, + 0, + 65, + 1, + 0, + 0, + 0, + 0, + 67, + 1, + 0, + 0, + 0, + 0, + 69, + 1, + 0, + 0, + 0, + 0, + 71, + 1, + 0, + 0, + 0, + 0, + 73, + 1, + 0, + 0, + 0, + 0, + 75, + 1, + 0, + 0, + 0, + 0, + 77, + 1, + 0, + 0, + 0, + 0, + 79, + 1, + 0, + 0, + 0, + 0, + 81, + 1, + 0, + 0, + 0, + 0, + 83, + 1, + 0, + 0, + 0, + 0, + 85, + 1, + 0, + 0, + 0, + 0, + 87, + 1, + 0, + 0, + 0, + 0, + 89, + 1, + 0, + 0, + 0, + 0, + 91, + 1, + 0, + 0, + 0, + 0, + 93, + 1, + 0, + 0, + 0, + 0, + 95, + 1, + 0, + 0, + 0, + 1, + 97, + 1, + 0, + 0, + 0, + 3, + 106, + 1, + 0, + 0, + 0, + 5, + 112, + 1, + 0, + 0, + 0, + 7, + 125, + 1, + 0, + 0, + 0, + 9, + 133, + 1, + 0, + 0, + 0, + 11, + 141, + 1, + 0, + 0, + 0, + 13, + 150, + 1, + 0, + 0, + 0, + 15, + 155, + 1, + 0, + 0, + 0, + 17, + 159, + 1, + 0, + 0, + 0, + 19, + 169, + 1, + 0, + 0, + 0, + 21, + 176, + 1, + 0, + 0, + 0, + 23, + 185, + 1, + 0, + 0, + 0, + 25, + 187, + 1, + 0, + 0, + 0, + 27, + 189, + 1, + 0, + 0, + 0, + 29, + 191, + 1, + 0, + 0, + 0, + 31, + 194, + 1, + 0, + 0, + 0, + 33, + 198, + 1, + 0, + 0, + 0, + 35, + 200, + 1, + 0, + 0, + 0, + 37, + 202, + 1, + 0, + 0, + 0, + 39, + 204, + 1, + 0, + 0, + 0, + 41, + 206, + 1, + 0, + 0, + 0, + 43, + 208, + 1, + 0, + 0, + 0, + 45, + 210, + 1, + 0, + 0, + 0, + 47, + 214, + 1, + 0, + 0, + 0, + 49, + 218, + 1, + 0, + 0, + 0, + 51, + 220, + 1, + 0, + 0, + 0, + 53, + 222, + 1, + 0, + 0, + 0, + 55, + 224, + 1, + 0, + 0, + 0, + 57, + 226, + 1, + 0, + 0, + 0, + 59, + 228, + 1, + 0, + 0, + 0, + 61, + 230, + 1, + 0, + 0, + 0, + 63, + 233, + 1, + 0, + 0, + 0, + 65, + 236, + 1, + 0, + 0, + 0, + 67, + 239, + 1, + 0, + 0, + 0, + 69, + 241, + 1, + 0, + 0, + 0, + 71, + 243, + 1, + 0, + 0, + 0, + 73, + 245, + 1, + 0, + 0, + 0, + 75, + 248, + 1, + 0, + 0, + 0, + 77, + 250, + 1, + 0, + 0, + 0, + 79, + 253, + 1, + 0, + 0, + 0, + 81, + 256, + 1, + 0, + 0, + 0, + 83, + 259, + 1, + 0, + 0, + 0, + 85, + 261, + 1, + 0, + 0, + 0, + 87, + 263, + 1, + 0, + 0, + 0, + 89, + 265, + 1, + 0, + 0, + 0, + 91, + 267, + 1, + 0, + 0, + 0, + 93, + 278, + 1, + 0, + 0, + 0, + 95, + 292, + 1, + 0, + 0, + 0, + 97, + 98, + 5, + 97, + 0, + 0, + 98, + 99, + 5, + 98, + 0, + 0, + 99, + 100, + 5, + 115, + 0, + 0, + 100, + 101, + 5, + 116, + 0, + 0, + 101, + 102, + 5, + 114, + 0, + 0, + 102, + 103, + 5, + 97, + 0, + 0, + 103, + 104, + 5, + 99, + 0, + 0, + 104, + 105, + 5, + 116, + 0, + 0, + 105, + 2, + 1, + 0, + 0, + 0, + 106, + 107, + 5, + 97, + 0, + 0, + 107, + 108, + 5, + 115, + 0, + 0, + 108, + 109, + 5, + 115, + 0, + 0, + 109, + 110, + 5, + 101, + 0, + 0, + 110, + 111, + 5, + 116, + 0, + 0, + 111, + 4, + 1, + 0, + 0, + 0, + 112, + 113, + 5, + 97, + 0, + 0, + 113, + 114, + 5, + 115, + 0, + 0, + 114, + 115, + 5, + 115, + 0, + 0, + 115, + 116, + 5, + 111, + 0, + 0, + 116, + 117, + 5, + 99, + 0, + 0, + 117, + 118, + 5, + 105, + 0, + 0, + 118, + 119, + 5, + 97, + 0, + 0, + 119, + 120, + 5, + 116, + 0, + 0, + 120, + 121, + 5, + 105, + 0, + 0, + 121, + 122, + 5, + 111, + 0, + 0, + 122, + 123, + 5, + 110, + 0, + 0, + 123, + 124, + 5, + 115, + 0, + 0, + 124, + 6, + 1, + 0, + 0, + 0, + 125, + 126, + 5, + 101, + 0, + 0, + 126, + 127, + 5, + 120, + 0, + 0, + 127, + 128, + 5, + 116, + 0, + 0, + 128, + 129, + 5, + 101, + 0, + 0, + 129, + 130, + 5, + 110, + 0, + 0, + 130, + 131, + 5, + 100, + 0, + 0, + 131, + 132, + 5, + 115, + 0, + 0, + 132, + 8, + 1, + 0, + 0, + 0, + 133, + 134, + 5, + 105, + 0, + 0, + 134, + 135, + 5, + 110, + 0, + 0, + 135, + 136, + 5, + 99, + 0, + 0, + 136, + 137, + 5, + 108, + 0, + 0, + 137, + 138, + 5, + 117, + 0, + 0, + 138, + 139, + 5, + 100, + 0, + 0, + 139, + 140, + 5, + 101, + 0, + 0, + 140, + 10, + 1, + 0, + 0, + 0, + 141, + 142, + 5, + 99, + 0, + 0, + 142, + 143, + 5, + 97, + 0, + 0, + 143, + 144, + 5, + 116, + 0, + 0, + 144, + 145, + 5, + 101, + 0, + 0, + 145, + 146, + 5, + 103, + 0, + 0, + 146, + 147, + 5, + 111, + 0, + 0, + 147, + 148, + 5, + 114, + 0, + 0, + 148, + 149, + 5, + 121, + 0, + 0, + 149, + 12, + 1, + 0, + 0, + 0, + 150, + 151, + 5, + 105, + 0, + 0, + 151, + 152, + 5, + 110, + 0, + 0, + 152, + 153, + 5, + 102, + 0, + 0, + 153, + 154, + 5, + 111, + 0, + 0, + 154, + 14, + 1, + 0, + 0, + 0, + 155, + 156, + 5, + 108, + 0, + 0, + 156, + 157, + 5, + 101, + 0, + 0, + 157, + 158, + 5, + 116, + 0, + 0, + 158, + 16, + 1, + 0, + 0, + 0, + 159, + 163, + 5, + 34, + 0, + 0, + 160, + 162, + 9, + 0, + 0, + 0, + 161, + 160, + 1, + 0, + 0, + 0, + 162, + 165, + 1, + 0, + 0, + 0, + 163, + 164, + 1, + 0, + 0, + 0, + 163, + 161, + 1, + 0, + 0, + 0, + 164, + 166, + 1, + 0, + 0, + 0, + 165, + 163, + 1, + 0, + 0, + 0, + 166, + 167, + 5, + 34, + 0, + 0, + 167, + 18, + 1, + 0, + 0, + 0, + 168, + 170, + 7, + 0, + 0, + 0, + 169, + 168, + 1, + 0, + 0, + 0, + 170, + 171, + 1, + 0, + 0, + 0, + 171, + 169, + 1, + 0, + 0, + 0, + 171, + 172, + 1, + 0, + 0, + 0, + 172, + 20, + 1, + 0, + 0, + 0, + 173, + 175, + 7, + 0, + 0, + 0, + 174, + 173, + 1, + 0, + 0, + 0, + 175, + 178, + 1, + 0, + 0, + 0, + 176, + 174, + 1, + 0, + 0, + 0, + 176, + 177, + 1, + 0, + 0, + 0, + 177, + 179, + 1, + 0, + 0, + 0, + 178, + 176, + 1, + 0, + 0, + 0, + 179, + 181, + 3, + 67, + 33, + 0, + 180, + 182, + 7, + 0, + 0, + 0, + 181, + 180, + 1, + 0, + 0, + 0, + 182, + 183, + 1, + 0, + 0, + 0, + 183, + 181, + 1, + 0, + 0, + 0, + 183, + 184, + 1, + 0, + 0, + 0, + 184, + 22, + 1, + 0, + 0, + 0, + 185, + 186, + 5, + 69, + 0, + 0, + 186, + 24, + 1, + 0, + 0, + 0, + 187, + 188, + 5, + 67, + 0, + 0, + 188, + 26, + 1, + 0, + 0, + 0, + 189, + 190, + 5, + 73, + 0, + 0, + 190, + 28, + 1, + 0, + 0, + 0, + 191, + 192, + 5, + 65, + 0, + 0, + 192, + 30, + 1, + 0, + 0, + 0, + 193, + 195, + 7, + 1, + 0, + 0, + 194, + 193, + 1, + 0, + 0, + 0, + 195, + 196, + 1, + 0, + 0, + 0, + 196, + 194, + 1, + 0, + 0, + 0, + 196, + 197, + 1, + 0, + 0, + 0, + 197, + 32, + 1, + 0, + 0, + 0, + 198, + 199, + 5, + 40, + 0, + 0, + 199, + 34, + 1, + 0, + 0, + 0, + 200, + 201, + 5, + 41, + 0, + 0, + 201, + 36, + 1, + 0, + 0, + 0, + 202, + 203, + 5, + 123, + 0, + 0, + 203, + 38, + 1, + 0, + 0, + 0, + 204, + 205, + 5, + 125, + 0, + 0, + 205, + 40, + 1, + 0, + 0, + 0, + 206, + 207, + 5, + 35, + 0, + 0, + 207, + 42, + 1, + 0, + 0, + 0, + 208, + 209, + 5, + 58, + 0, + 0, + 209, + 44, + 1, + 0, + 0, + 0, + 210, + 211, + 5, + 60, + 0, + 0, + 211, + 212, + 5, + 45, + 0, + 0, + 212, + 213, + 5, + 45, + 0, + 0, + 213, + 46, + 1, + 0, + 0, + 0, + 214, + 215, + 5, + 45, + 0, + 0, + 215, + 216, + 5, + 45, + 0, + 0, + 216, + 217, + 5, + 62, + 0, + 0, + 217, + 48, + 1, + 0, + 0, + 0, + 218, + 219, + 5, + 91, + 0, + 0, + 219, + 50, + 1, + 0, + 0, + 0, + 220, + 221, + 5, + 93, + 0, + 0, + 221, + 52, + 1, + 0, + 0, + 0, + 222, + 223, + 5, + 42, + 0, + 0, + 223, + 54, + 1, + 0, + 0, + 0, + 224, + 225, + 5, + 49, + 0, + 0, + 225, + 56, + 1, + 0, + 0, + 0, + 226, + 227, + 5, + 61, + 0, + 0, + 227, + 58, + 1, + 0, + 0, + 0, + 228, + 229, + 5, + 45, + 0, + 0, + 229, + 60, + 1, + 0, + 0, + 0, + 230, + 231, + 5, + 47, + 0, + 0, + 231, + 232, + 5, + 92, + 0, + 0, + 232, + 62, + 1, + 0, + 0, + 0, + 233, + 234, + 5, + 92, + 0, + 0, + 234, + 235, + 5, + 47, + 0, + 0, + 235, + 64, + 1, + 0, + 0, + 0, + 236, + 237, + 5, + 46, + 0, + 0, + 237, + 238, + 5, + 46, + 0, + 0, + 238, + 66, + 1, + 0, + 0, + 0, + 239, + 240, + 5, + 46, + 0, + 0, + 240, + 68, + 1, + 0, + 0, + 0, + 241, + 242, + 5, + 38, + 0, + 0, + 242, + 70, + 1, + 0, + 0, + 0, + 243, + 244, + 5, + 124, + 0, + 0, + 244, + 72, + 1, + 0, + 0, + 0, + 245, + 246, + 5, + 33, + 0, + 0, + 246, + 247, + 5, + 69, + 0, + 0, + 247, + 74, + 1, + 0, + 0, + 0, + 248, + 249, + 5, + 64, + 0, + 0, + 249, + 76, + 1, + 0, + 0, + 0, + 250, + 251, + 5, + 60, + 0, + 0, + 251, + 252, + 5, + 45, + 0, + 0, + 252, + 78, + 1, + 0, + 0, + 0, + 253, + 254, + 5, + 43, + 0, + 0, + 254, + 255, + 5, + 62, + 0, + 0, + 255, + 80, + 1, + 0, + 0, + 0, + 256, + 257, + 5, + 45, + 0, + 0, + 257, + 258, + 5, + 62, + 0, + 0, + 258, + 82, + 1, + 0, + 0, + 0, + 259, + 260, + 5, + 44, + 0, + 0, + 260, + 84, + 1, + 0, + 0, + 0, + 261, + 262, + 5, + 43, + 0, + 0, + 262, + 86, + 1, + 0, + 0, + 0, + 263, + 264, + 5, + 47, + 0, + 0, + 264, + 88, + 1, + 0, + 0, + 0, + 265, + 266, + 5, + 94, + 0, + 0, + 266, + 90, + 1, + 0, + 0, + 0, + 267, + 268, + 5, + 47, + 0, + 0, + 268, + 269, + 5, + 47, + 0, + 0, + 269, + 273, + 1, + 0, + 0, + 0, + 270, + 272, + 8, + 2, + 0, + 0, + 271, + 270, + 1, + 0, + 0, + 0, + 272, + 275, + 1, + 0, + 0, + 0, + 273, + 271, + 1, + 0, + 0, + 0, + 273, + 274, + 1, + 0, + 0, + 0, + 274, + 276, + 1, + 0, + 0, + 0, + 275, + 273, + 1, + 0, + 0, + 0, + 276, + 277, + 6, + 45, + 0, + 0, + 277, + 92, + 1, + 0, + 0, + 0, + 278, + 279, + 5, + 47, + 0, + 0, + 279, + 280, + 5, + 42, + 0, + 0, + 280, + 284, + 1, + 0, + 0, + 0, + 281, + 283, + 9, + 0, + 0, + 0, + 282, + 281, + 1, + 0, + 0, + 0, + 283, + 286, + 1, + 0, + 0, + 0, + 284, + 285, + 1, + 0, + 0, + 0, + 284, + 282, + 1, + 0, + 0, + 0, + 285, + 287, + 1, + 0, + 0, + 0, + 286, + 284, + 1, + 0, + 0, + 0, + 287, + 288, + 5, + 42, + 0, + 0, + 288, + 289, + 5, + 47, + 0, + 0, + 289, + 290, + 1, + 0, + 0, + 0, + 290, + 291, + 6, + 46, + 0, + 0, + 291, + 94, + 1, + 0, + 0, + 0, + 292, + 293, + 7, + 3, + 0, + 0, + 293, + 294, + 1, + 0, + 0, + 0, + 294, + 295, + 6, + 47, + 0, + 0, + 295, + 96, + 1, + 0, + 0, + 0, + 8, + 0, + 163, + 171, + 176, + 183, + 196, + 273, + 284, + 1, + 6, + 0, + 0, ] -class malLexer(Lexer): +class malLexer(Lexer): atn = ATNDeserializer().deserialize(serializedATN()) - decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + decisionsToDFA = [DFA(ds, i) for i, ds in enumerate(atn.decisionToState)] ABSTRACT = 1 ASSET = 2 @@ -169,44 +2637,165 @@ class malLexer(Lexer): MULTILINE_COMMENT = 47 WS = 48 - channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + channelNames = ['DEFAULT_TOKEN_CHANNEL', 'HIDDEN'] - modeNames = [ "DEFAULT_MODE" ] + modeNames = ['DEFAULT_MODE'] - literalNames = [ "", - "'abstract'", "'asset'", "'associations'", "'extends'", "'include'", - "'category'", "'info'", "'let'", "'E'", "'C'", "'I'", "'A'", - "'('", "')'", "'{'", "'}'", "'#'", "':'", "'<--'", "'-->'", - "'['", "']'", "'*'", "'1'", "'='", "'-'", "'/\\'", "'\\/'", - "'..'", "'.'", "'&'", "'|'", "'!E'", "'@'", "'<-'", "'+>'", - "'->'", "','", "'+'", "'/'", "'^'" ] + literalNames = [ + '', + "'abstract'", + "'asset'", + "'associations'", + "'extends'", + "'include'", + "'category'", + "'info'", + "'let'", + "'E'", + "'C'", + "'I'", + "'A'", + "'('", + "')'", + "'{'", + "'}'", + "'#'", + "':'", + "'<--'", + "'-->'", + "'['", + "']'", + "'*'", + "'1'", + "'='", + "'-'", + "'/\\'", + "'\\/'", + "'..'", + "'.'", + "'&'", + "'|'", + "'!E'", + "'@'", + "'<-'", + "'+>'", + "'->'", + "','", + "'+'", + "'/'", + "'^'", + ] - symbolicNames = [ "", - "ABSTRACT", "ASSET", "ASSOCIATIONS", "EXTENDS", "INCLUDE", "CATEGORY", - "INFO", "LET", "STRING", "INT", "FLOAT", "EXISTS", "C", "I", - "A", "ID", "LPAREN", "RPAREN", "LCURLY", "RCURLY", "HASH", "COLON", - "LARROW", "RARROW", "LSQUARE", "RSQUARE", "STAR", "ONE", "ASSIGN", - "MINUS", "INTERSECT", "UNION", "RANGE", "DOT", "AND", "OR", - "NOTEXISTS", "AT", "REQUIRES", "INHERITS", "LEADSTO", "COMMA", - "PLUS", "DIVIDE", "POWER", "INLINE_COMMENT", "MULTILINE_COMMENT", - "WS" ] + symbolicNames = [ + '', + 'ABSTRACT', + 'ASSET', + 'ASSOCIATIONS', + 'EXTENDS', + 'INCLUDE', + 'CATEGORY', + 'INFO', + 'LET', + 'STRING', + 'INT', + 'FLOAT', + 'EXISTS', + 'C', + 'I', + 'A', + 'ID', + 'LPAREN', + 'RPAREN', + 'LCURLY', + 'RCURLY', + 'HASH', + 'COLON', + 'LARROW', + 'RARROW', + 'LSQUARE', + 'RSQUARE', + 'STAR', + 'ONE', + 'ASSIGN', + 'MINUS', + 'INTERSECT', + 'UNION', + 'RANGE', + 'DOT', + 'AND', + 'OR', + 'NOTEXISTS', + 'AT', + 'REQUIRES', + 'INHERITS', + 'LEADSTO', + 'COMMA', + 'PLUS', + 'DIVIDE', + 'POWER', + 'INLINE_COMMENT', + 'MULTILINE_COMMENT', + 'WS', + ] - ruleNames = [ "ABSTRACT", "ASSET", "ASSOCIATIONS", "EXTENDS", "INCLUDE", - "CATEGORY", "INFO", "LET", "STRING", "INT", "FLOAT", "EXISTS", - "C", "I", "A", "ID", "LPAREN", "RPAREN", "LCURLY", "RCURLY", - "HASH", "COLON", "LARROW", "RARROW", "LSQUARE", "RSQUARE", - "STAR", "ONE", "ASSIGN", "MINUS", "INTERSECT", "UNION", - "RANGE", "DOT", "AND", "OR", "NOTEXISTS", "AT", "REQUIRES", - "INHERITS", "LEADSTO", "COMMA", "PLUS", "DIVIDE", "POWER", - "INLINE_COMMENT", "MULTILINE_COMMENT", "WS" ] + ruleNames = [ + 'ABSTRACT', + 'ASSET', + 'ASSOCIATIONS', + 'EXTENDS', + 'INCLUDE', + 'CATEGORY', + 'INFO', + 'LET', + 'STRING', + 'INT', + 'FLOAT', + 'EXISTS', + 'C', + 'I', + 'A', + 'ID', + 'LPAREN', + 'RPAREN', + 'LCURLY', + 'RCURLY', + 'HASH', + 'COLON', + 'LARROW', + 'RARROW', + 'LSQUARE', + 'RSQUARE', + 'STAR', + 'ONE', + 'ASSIGN', + 'MINUS', + 'INTERSECT', + 'UNION', + 'RANGE', + 'DOT', + 'AND', + 'OR', + 'NOTEXISTS', + 'AT', + 'REQUIRES', + 'INHERITS', + 'LEADSTO', + 'COMMA', + 'PLUS', + 'DIVIDE', + 'POWER', + 'INLINE_COMMENT', + 'MULTILINE_COMMENT', + 'WS', + ] - grammarFileName = "mal.g4" + grammarFileName = 'mal.g4' - def __init__(self, input=None, output:TextIO = sys.stdout): + def __init__(self, input=None, output: TextIO = sys.stdout) -> None: super().__init__(input, output) - self.checkVersion("4.13.1") - self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self.checkVersion('4.13.1') + self._interp = LexerATNSimulator( + self, self.atn, self.decisionsToDFA, PredictionContextCache() + ) self._actions = None self._predicates = None - - diff --git a/maltoolbox/language/compiler/mal_parser.py b/maltoolbox/language/compiler/mal_parser.py index d686c5e1..29f12d58 100644 --- a/maltoolbox/language/compiler/mal_parser.py +++ b/maltoolbox/language/compiler/mal_parser.py @@ -1,163 +1,3024 @@ # mypy: ignore-errors # Generated from mal.g4 by ANTLR 4.13.1 # encoding: utf-8 -from antlr4 import * import sys + +from antlr4 import * + if sys.version_info[1] > 5: - from typing import TextIO + from typing import TextIO else: - from typing.io import TextIO + from typing.io import TextIO + def serializedATN(): return [ - 4,1,48,338,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, - 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13, - 2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20, - 7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26, - 2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7,32,2,33, - 7,33,1,0,4,0,70,8,0,11,0,12,0,71,1,0,3,0,75,8,0,1,1,1,1,1,1,1,1, - 3,1,81,8,1,1,2,1,2,1,2,1,3,1,3,1,3,1,3,1,3,1,4,1,4,1,4,5,4,94,8, - 4,10,4,12,4,97,9,4,1,4,1,4,5,4,101,8,4,10,4,12,4,104,9,4,1,4,1,4, - 1,5,1,5,1,5,1,5,1,5,1,6,3,6,114,8,6,1,6,1,6,1,6,1,6,3,6,120,8,6, - 1,6,5,6,123,8,6,10,6,12,6,126,9,6,1,6,1,6,1,6,5,6,131,8,6,10,6,12, - 6,134,9,6,1,6,1,6,1,7,1,7,1,7,5,7,141,8,7,10,7,12,7,144,9,7,1,7, - 3,7,147,8,7,1,7,3,7,150,8,7,1,7,5,7,153,8,7,10,7,12,7,156,9,7,1, - 7,3,7,159,8,7,1,7,3,7,162,8,7,1,8,1,8,1,9,1,9,1,9,1,10,1,10,1,10, - 1,10,5,10,173,8,10,10,10,12,10,176,9,10,1,10,1,10,1,11,1,11,1,12, - 1,12,1,12,1,12,1,13,1,13,1,13,5,13,189,8,13,10,13,12,13,192,9,13, - 1,14,1,14,1,14,5,14,197,8,14,10,14,12,14,200,9,14,1,15,1,15,1,15, - 3,15,205,8,15,1,16,1,16,1,16,1,16,1,16,1,16,3,16,213,8,16,1,17,1, - 17,1,17,1,17,1,17,5,17,220,8,17,10,17,12,17,223,9,17,3,17,225,8, - 17,1,17,3,17,228,8,17,1,18,1,18,1,18,1,18,5,18,234,8,18,10,18,12, - 18,237,9,18,1,19,1,19,1,19,1,19,5,19,243,8,19,10,19,12,19,246,9, - 19,1,20,1,20,1,21,1,21,1,21,1,21,1,21,1,22,1,22,1,22,1,22,5,22,259, - 8,22,10,22,12,22,262,9,22,1,23,1,23,1,23,5,23,267,8,23,10,23,12, - 23,270,9,23,1,24,1,24,1,24,1,24,1,24,1,24,1,24,1,24,1,24,3,24,281, - 8,24,1,24,3,24,284,8,24,1,24,5,24,287,8,24,10,24,12,24,290,9,24, - 1,25,1,25,1,26,1,26,1,26,1,26,1,27,1,27,1,28,1,28,1,28,5,28,303, - 8,28,10,28,12,28,306,9,28,1,28,1,28,1,29,1,29,1,29,1,29,1,29,1,29, - 1,29,1,29,1,29,1,29,5,29,320,8,29,10,29,12,29,323,9,29,1,30,1,30, - 1,30,1,30,1,31,1,31,1,31,3,31,332,8,31,1,32,1,32,1,33,1,33,1,33, - 0,0,34,0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40, - 42,44,46,48,50,52,54,56,58,60,62,64,66,0,8,3,0,12,12,21,21,35,37, - 1,0,13,15,2,0,30,30,43,43,2,0,27,27,44,44,1,0,40,41,1,0,10,11,1, - 0,30,32,2,0,10,10,27,27,341,0,74,1,0,0,0,2,80,1,0,0,0,4,82,1,0,0, - 0,6,85,1,0,0,0,8,90,1,0,0,0,10,107,1,0,0,0,12,113,1,0,0,0,14,137, - 1,0,0,0,16,163,1,0,0,0,18,165,1,0,0,0,20,168,1,0,0,0,22,179,1,0, - 0,0,24,181,1,0,0,0,26,185,1,0,0,0,28,193,1,0,0,0,30,201,1,0,0,0, - 32,212,1,0,0,0,34,214,1,0,0,0,36,229,1,0,0,0,38,238,1,0,0,0,40,247, - 1,0,0,0,42,249,1,0,0,0,44,254,1,0,0,0,46,263,1,0,0,0,48,280,1,0, - 0,0,50,291,1,0,0,0,52,293,1,0,0,0,54,297,1,0,0,0,56,299,1,0,0,0, - 58,309,1,0,0,0,60,324,1,0,0,0,62,328,1,0,0,0,64,333,1,0,0,0,66,335, - 1,0,0,0,68,70,3,2,1,0,69,68,1,0,0,0,70,71,1,0,0,0,71,69,1,0,0,0, - 71,72,1,0,0,0,72,75,1,0,0,0,73,75,5,0,0,1,74,69,1,0,0,0,74,73,1, - 0,0,0,75,1,1,0,0,0,76,81,3,4,2,0,77,81,3,6,3,0,78,81,3,8,4,0,79, - 81,3,56,28,0,80,76,1,0,0,0,80,77,1,0,0,0,80,78,1,0,0,0,80,79,1,0, - 0,0,81,3,1,0,0,0,82,83,5,5,0,0,83,84,5,9,0,0,84,5,1,0,0,0,85,86, - 5,21,0,0,86,87,5,16,0,0,87,88,5,22,0,0,88,89,5,9,0,0,89,7,1,0,0, - 0,90,91,5,6,0,0,91,95,5,16,0,0,92,94,3,10,5,0,93,92,1,0,0,0,94,97, - 1,0,0,0,95,93,1,0,0,0,95,96,1,0,0,0,96,98,1,0,0,0,97,95,1,0,0,0, - 98,102,5,19,0,0,99,101,3,12,6,0,100,99,1,0,0,0,101,104,1,0,0,0,102, - 100,1,0,0,0,102,103,1,0,0,0,103,105,1,0,0,0,104,102,1,0,0,0,105, - 106,5,20,0,0,106,9,1,0,0,0,107,108,5,16,0,0,108,109,5,7,0,0,109, - 110,5,22,0,0,110,111,5,9,0,0,111,11,1,0,0,0,112,114,5,1,0,0,113, - 112,1,0,0,0,113,114,1,0,0,0,114,115,1,0,0,0,115,116,5,2,0,0,116, - 119,5,16,0,0,117,118,5,4,0,0,118,120,5,16,0,0,119,117,1,0,0,0,119, - 120,1,0,0,0,120,124,1,0,0,0,121,123,3,10,5,0,122,121,1,0,0,0,123, - 126,1,0,0,0,124,122,1,0,0,0,124,125,1,0,0,0,125,127,1,0,0,0,126, - 124,1,0,0,0,127,132,5,19,0,0,128,131,3,14,7,0,129,131,3,42,21,0, - 130,128,1,0,0,0,130,129,1,0,0,0,131,134,1,0,0,0,132,130,1,0,0,0, - 132,133,1,0,0,0,133,135,1,0,0,0,134,132,1,0,0,0,135,136,5,20,0,0, - 136,13,1,0,0,0,137,138,3,16,8,0,138,142,5,16,0,0,139,141,3,18,9, - 0,140,139,1,0,0,0,141,144,1,0,0,0,142,140,1,0,0,0,142,143,1,0,0, - 0,143,146,1,0,0,0,144,142,1,0,0,0,145,147,3,20,10,0,146,145,1,0, - 0,0,146,147,1,0,0,0,147,149,1,0,0,0,148,150,3,24,12,0,149,148,1, - 0,0,0,149,150,1,0,0,0,150,154,1,0,0,0,151,153,3,10,5,0,152,151,1, - 0,0,0,153,156,1,0,0,0,154,152,1,0,0,0,154,155,1,0,0,0,155,158,1, - 0,0,0,156,154,1,0,0,0,157,159,3,36,18,0,158,157,1,0,0,0,158,159, - 1,0,0,0,159,161,1,0,0,0,160,162,3,38,19,0,161,160,1,0,0,0,161,162, - 1,0,0,0,162,15,1,0,0,0,163,164,7,0,0,0,164,17,1,0,0,0,165,166,5, - 38,0,0,166,167,5,16,0,0,167,19,1,0,0,0,168,169,5,19,0,0,169,174, - 3,22,11,0,170,171,5,42,0,0,171,173,3,22,11,0,172,170,1,0,0,0,173, - 176,1,0,0,0,174,172,1,0,0,0,174,175,1,0,0,0,175,177,1,0,0,0,176, - 174,1,0,0,0,177,178,5,20,0,0,178,21,1,0,0,0,179,180,7,1,0,0,180, - 23,1,0,0,0,181,182,5,25,0,0,182,183,3,26,13,0,183,184,5,26,0,0,184, - 25,1,0,0,0,185,190,3,28,14,0,186,187,7,2,0,0,187,189,3,28,14,0,188, - 186,1,0,0,0,189,192,1,0,0,0,190,188,1,0,0,0,190,191,1,0,0,0,191, - 27,1,0,0,0,192,190,1,0,0,0,193,198,3,30,15,0,194,195,7,3,0,0,195, - 197,3,30,15,0,196,194,1,0,0,0,197,200,1,0,0,0,198,196,1,0,0,0,198, - 199,1,0,0,0,199,29,1,0,0,0,200,198,1,0,0,0,201,204,3,32,16,0,202, - 203,5,45,0,0,203,205,3,32,16,0,204,202,1,0,0,0,204,205,1,0,0,0,205, - 31,1,0,0,0,206,213,3,34,17,0,207,208,5,17,0,0,208,209,3,26,13,0, - 209,210,5,18,0,0,210,213,1,0,0,0,211,213,3,40,20,0,212,206,1,0,0, - 0,212,207,1,0,0,0,212,211,1,0,0,0,213,33,1,0,0,0,214,227,5,16,0, - 0,215,224,5,17,0,0,216,221,3,40,20,0,217,218,5,42,0,0,218,220,3, - 40,20,0,219,217,1,0,0,0,220,223,1,0,0,0,221,219,1,0,0,0,221,222, - 1,0,0,0,222,225,1,0,0,0,223,221,1,0,0,0,224,216,1,0,0,0,224,225, - 1,0,0,0,225,226,1,0,0,0,226,228,5,18,0,0,227,215,1,0,0,0,227,228, - 1,0,0,0,228,35,1,0,0,0,229,230,5,39,0,0,230,235,3,44,22,0,231,232, - 5,42,0,0,232,234,3,44,22,0,233,231,1,0,0,0,234,237,1,0,0,0,235,233, - 1,0,0,0,235,236,1,0,0,0,236,37,1,0,0,0,237,235,1,0,0,0,238,239,7, - 4,0,0,239,244,3,44,22,0,240,241,5,42,0,0,241,243,3,44,22,0,242,240, - 1,0,0,0,243,246,1,0,0,0,244,242,1,0,0,0,244,245,1,0,0,0,245,39,1, - 0,0,0,246,244,1,0,0,0,247,248,7,5,0,0,248,41,1,0,0,0,249,250,5,8, - 0,0,250,251,5,16,0,0,251,252,5,29,0,0,252,253,3,44,22,0,253,43,1, - 0,0,0,254,260,3,46,23,0,255,256,3,54,27,0,256,257,3,46,23,0,257, - 259,1,0,0,0,258,255,1,0,0,0,259,262,1,0,0,0,260,258,1,0,0,0,260, - 261,1,0,0,0,261,45,1,0,0,0,262,260,1,0,0,0,263,268,3,48,24,0,264, - 265,5,34,0,0,265,267,3,48,24,0,266,264,1,0,0,0,267,270,1,0,0,0,268, - 266,1,0,0,0,268,269,1,0,0,0,269,47,1,0,0,0,270,268,1,0,0,0,271,272, - 5,17,0,0,272,273,3,44,22,0,273,274,5,18,0,0,274,281,1,0,0,0,275, - 276,3,50,25,0,276,277,5,17,0,0,277,278,5,18,0,0,278,281,1,0,0,0, - 279,281,5,16,0,0,280,271,1,0,0,0,280,275,1,0,0,0,280,279,1,0,0,0, - 281,283,1,0,0,0,282,284,5,27,0,0,283,282,1,0,0,0,283,284,1,0,0,0, - 284,288,1,0,0,0,285,287,3,52,26,0,286,285,1,0,0,0,287,290,1,0,0, - 0,288,286,1,0,0,0,288,289,1,0,0,0,289,49,1,0,0,0,290,288,1,0,0,0, - 291,292,5,16,0,0,292,51,1,0,0,0,293,294,5,25,0,0,294,295,5,16,0, - 0,295,296,5,26,0,0,296,53,1,0,0,0,297,298,7,6,0,0,298,55,1,0,0,0, - 299,300,5,3,0,0,300,304,5,19,0,0,301,303,3,58,29,0,302,301,1,0,0, - 0,303,306,1,0,0,0,304,302,1,0,0,0,304,305,1,0,0,0,305,307,1,0,0, - 0,306,304,1,0,0,0,307,308,5,20,0,0,308,57,1,0,0,0,309,310,5,16,0, - 0,310,311,3,60,30,0,311,312,3,62,31,0,312,313,5,23,0,0,313,314,3, - 66,33,0,314,315,5,24,0,0,315,316,3,62,31,0,316,317,3,60,30,0,317, - 321,5,16,0,0,318,320,3,10,5,0,319,318,1,0,0,0,320,323,1,0,0,0,321, - 319,1,0,0,0,321,322,1,0,0,0,322,59,1,0,0,0,323,321,1,0,0,0,324,325, - 5,25,0,0,325,326,5,16,0,0,326,327,5,26,0,0,327,61,1,0,0,0,328,331, - 3,64,32,0,329,330,5,33,0,0,330,332,3,64,32,0,331,329,1,0,0,0,331, - 332,1,0,0,0,332,63,1,0,0,0,333,334,7,7,0,0,334,65,1,0,0,0,335,336, - 5,16,0,0,336,67,1,0,0,0,34,71,74,80,95,102,113,119,124,130,132,142, - 146,149,154,158,161,174,190,198,204,212,221,224,227,235,244,260, - 268,280,283,288,304,321,331 + 4, + 1, + 48, + 338, + 2, + 0, + 7, + 0, + 2, + 1, + 7, + 1, + 2, + 2, + 7, + 2, + 2, + 3, + 7, + 3, + 2, + 4, + 7, + 4, + 2, + 5, + 7, + 5, + 2, + 6, + 7, + 6, + 2, + 7, + 7, + 7, + 2, + 8, + 7, + 8, + 2, + 9, + 7, + 9, + 2, + 10, + 7, + 10, + 2, + 11, + 7, + 11, + 2, + 12, + 7, + 12, + 2, + 13, + 7, + 13, + 2, + 14, + 7, + 14, + 2, + 15, + 7, + 15, + 2, + 16, + 7, + 16, + 2, + 17, + 7, + 17, + 2, + 18, + 7, + 18, + 2, + 19, + 7, + 19, + 2, + 20, + 7, + 20, + 2, + 21, + 7, + 21, + 2, + 22, + 7, + 22, + 2, + 23, + 7, + 23, + 2, + 24, + 7, + 24, + 2, + 25, + 7, + 25, + 2, + 26, + 7, + 26, + 2, + 27, + 7, + 27, + 2, + 28, + 7, + 28, + 2, + 29, + 7, + 29, + 2, + 30, + 7, + 30, + 2, + 31, + 7, + 31, + 2, + 32, + 7, + 32, + 2, + 33, + 7, + 33, + 1, + 0, + 4, + 0, + 70, + 8, + 0, + 11, + 0, + 12, + 0, + 71, + 1, + 0, + 3, + 0, + 75, + 8, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 3, + 1, + 81, + 8, + 1, + 1, + 2, + 1, + 2, + 1, + 2, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 4, + 1, + 4, + 1, + 4, + 5, + 4, + 94, + 8, + 4, + 10, + 4, + 12, + 4, + 97, + 9, + 4, + 1, + 4, + 1, + 4, + 5, + 4, + 101, + 8, + 4, + 10, + 4, + 12, + 4, + 104, + 9, + 4, + 1, + 4, + 1, + 4, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 5, + 1, + 6, + 3, + 6, + 114, + 8, + 6, + 1, + 6, + 1, + 6, + 1, + 6, + 1, + 6, + 3, + 6, + 120, + 8, + 6, + 1, + 6, + 5, + 6, + 123, + 8, + 6, + 10, + 6, + 12, + 6, + 126, + 9, + 6, + 1, + 6, + 1, + 6, + 1, + 6, + 5, + 6, + 131, + 8, + 6, + 10, + 6, + 12, + 6, + 134, + 9, + 6, + 1, + 6, + 1, + 6, + 1, + 7, + 1, + 7, + 1, + 7, + 5, + 7, + 141, + 8, + 7, + 10, + 7, + 12, + 7, + 144, + 9, + 7, + 1, + 7, + 3, + 7, + 147, + 8, + 7, + 1, + 7, + 3, + 7, + 150, + 8, + 7, + 1, + 7, + 5, + 7, + 153, + 8, + 7, + 10, + 7, + 12, + 7, + 156, + 9, + 7, + 1, + 7, + 3, + 7, + 159, + 8, + 7, + 1, + 7, + 3, + 7, + 162, + 8, + 7, + 1, + 8, + 1, + 8, + 1, + 9, + 1, + 9, + 1, + 9, + 1, + 10, + 1, + 10, + 1, + 10, + 1, + 10, + 5, + 10, + 173, + 8, + 10, + 10, + 10, + 12, + 10, + 176, + 9, + 10, + 1, + 10, + 1, + 10, + 1, + 11, + 1, + 11, + 1, + 12, + 1, + 12, + 1, + 12, + 1, + 12, + 1, + 13, + 1, + 13, + 1, + 13, + 5, + 13, + 189, + 8, + 13, + 10, + 13, + 12, + 13, + 192, + 9, + 13, + 1, + 14, + 1, + 14, + 1, + 14, + 5, + 14, + 197, + 8, + 14, + 10, + 14, + 12, + 14, + 200, + 9, + 14, + 1, + 15, + 1, + 15, + 1, + 15, + 3, + 15, + 205, + 8, + 15, + 1, + 16, + 1, + 16, + 1, + 16, + 1, + 16, + 1, + 16, + 1, + 16, + 3, + 16, + 213, + 8, + 16, + 1, + 17, + 1, + 17, + 1, + 17, + 1, + 17, + 1, + 17, + 5, + 17, + 220, + 8, + 17, + 10, + 17, + 12, + 17, + 223, + 9, + 17, + 3, + 17, + 225, + 8, + 17, + 1, + 17, + 3, + 17, + 228, + 8, + 17, + 1, + 18, + 1, + 18, + 1, + 18, + 1, + 18, + 5, + 18, + 234, + 8, + 18, + 10, + 18, + 12, + 18, + 237, + 9, + 18, + 1, + 19, + 1, + 19, + 1, + 19, + 1, + 19, + 5, + 19, + 243, + 8, + 19, + 10, + 19, + 12, + 19, + 246, + 9, + 19, + 1, + 20, + 1, + 20, + 1, + 21, + 1, + 21, + 1, + 21, + 1, + 21, + 1, + 21, + 1, + 22, + 1, + 22, + 1, + 22, + 1, + 22, + 5, + 22, + 259, + 8, + 22, + 10, + 22, + 12, + 22, + 262, + 9, + 22, + 1, + 23, + 1, + 23, + 1, + 23, + 5, + 23, + 267, + 8, + 23, + 10, + 23, + 12, + 23, + 270, + 9, + 23, + 1, + 24, + 1, + 24, + 1, + 24, + 1, + 24, + 1, + 24, + 1, + 24, + 1, + 24, + 1, + 24, + 1, + 24, + 3, + 24, + 281, + 8, + 24, + 1, + 24, + 3, + 24, + 284, + 8, + 24, + 1, + 24, + 5, + 24, + 287, + 8, + 24, + 10, + 24, + 12, + 24, + 290, + 9, + 24, + 1, + 25, + 1, + 25, + 1, + 26, + 1, + 26, + 1, + 26, + 1, + 26, + 1, + 27, + 1, + 27, + 1, + 28, + 1, + 28, + 1, + 28, + 5, + 28, + 303, + 8, + 28, + 10, + 28, + 12, + 28, + 306, + 9, + 28, + 1, + 28, + 1, + 28, + 1, + 29, + 1, + 29, + 1, + 29, + 1, + 29, + 1, + 29, + 1, + 29, + 1, + 29, + 1, + 29, + 1, + 29, + 1, + 29, + 5, + 29, + 320, + 8, + 29, + 10, + 29, + 12, + 29, + 323, + 9, + 29, + 1, + 30, + 1, + 30, + 1, + 30, + 1, + 30, + 1, + 31, + 1, + 31, + 1, + 31, + 3, + 31, + 332, + 8, + 31, + 1, + 32, + 1, + 32, + 1, + 33, + 1, + 33, + 1, + 33, + 0, + 0, + 34, + 0, + 2, + 4, + 6, + 8, + 10, + 12, + 14, + 16, + 18, + 20, + 22, + 24, + 26, + 28, + 30, + 32, + 34, + 36, + 38, + 40, + 42, + 44, + 46, + 48, + 50, + 52, + 54, + 56, + 58, + 60, + 62, + 64, + 66, + 0, + 8, + 3, + 0, + 12, + 12, + 21, + 21, + 35, + 37, + 1, + 0, + 13, + 15, + 2, + 0, + 30, + 30, + 43, + 43, + 2, + 0, + 27, + 27, + 44, + 44, + 1, + 0, + 40, + 41, + 1, + 0, + 10, + 11, + 1, + 0, + 30, + 32, + 2, + 0, + 10, + 10, + 27, + 27, + 341, + 0, + 74, + 1, + 0, + 0, + 0, + 2, + 80, + 1, + 0, + 0, + 0, + 4, + 82, + 1, + 0, + 0, + 0, + 6, + 85, + 1, + 0, + 0, + 0, + 8, + 90, + 1, + 0, + 0, + 0, + 10, + 107, + 1, + 0, + 0, + 0, + 12, + 113, + 1, + 0, + 0, + 0, + 14, + 137, + 1, + 0, + 0, + 0, + 16, + 163, + 1, + 0, + 0, + 0, + 18, + 165, + 1, + 0, + 0, + 0, + 20, + 168, + 1, + 0, + 0, + 0, + 22, + 179, + 1, + 0, + 0, + 0, + 24, + 181, + 1, + 0, + 0, + 0, + 26, + 185, + 1, + 0, + 0, + 0, + 28, + 193, + 1, + 0, + 0, + 0, + 30, + 201, + 1, + 0, + 0, + 0, + 32, + 212, + 1, + 0, + 0, + 0, + 34, + 214, + 1, + 0, + 0, + 0, + 36, + 229, + 1, + 0, + 0, + 0, + 38, + 238, + 1, + 0, + 0, + 0, + 40, + 247, + 1, + 0, + 0, + 0, + 42, + 249, + 1, + 0, + 0, + 0, + 44, + 254, + 1, + 0, + 0, + 0, + 46, + 263, + 1, + 0, + 0, + 0, + 48, + 280, + 1, + 0, + 0, + 0, + 50, + 291, + 1, + 0, + 0, + 0, + 52, + 293, + 1, + 0, + 0, + 0, + 54, + 297, + 1, + 0, + 0, + 0, + 56, + 299, + 1, + 0, + 0, + 0, + 58, + 309, + 1, + 0, + 0, + 0, + 60, + 324, + 1, + 0, + 0, + 0, + 62, + 328, + 1, + 0, + 0, + 0, + 64, + 333, + 1, + 0, + 0, + 0, + 66, + 335, + 1, + 0, + 0, + 0, + 68, + 70, + 3, + 2, + 1, + 0, + 69, + 68, + 1, + 0, + 0, + 0, + 70, + 71, + 1, + 0, + 0, + 0, + 71, + 69, + 1, + 0, + 0, + 0, + 71, + 72, + 1, + 0, + 0, + 0, + 72, + 75, + 1, + 0, + 0, + 0, + 73, + 75, + 5, + 0, + 0, + 1, + 74, + 69, + 1, + 0, + 0, + 0, + 74, + 73, + 1, + 0, + 0, + 0, + 75, + 1, + 1, + 0, + 0, + 0, + 76, + 81, + 3, + 4, + 2, + 0, + 77, + 81, + 3, + 6, + 3, + 0, + 78, + 81, + 3, + 8, + 4, + 0, + 79, + 81, + 3, + 56, + 28, + 0, + 80, + 76, + 1, + 0, + 0, + 0, + 80, + 77, + 1, + 0, + 0, + 0, + 80, + 78, + 1, + 0, + 0, + 0, + 80, + 79, + 1, + 0, + 0, + 0, + 81, + 3, + 1, + 0, + 0, + 0, + 82, + 83, + 5, + 5, + 0, + 0, + 83, + 84, + 5, + 9, + 0, + 0, + 84, + 5, + 1, + 0, + 0, + 0, + 85, + 86, + 5, + 21, + 0, + 0, + 86, + 87, + 5, + 16, + 0, + 0, + 87, + 88, + 5, + 22, + 0, + 0, + 88, + 89, + 5, + 9, + 0, + 0, + 89, + 7, + 1, + 0, + 0, + 0, + 90, + 91, + 5, + 6, + 0, + 0, + 91, + 95, + 5, + 16, + 0, + 0, + 92, + 94, + 3, + 10, + 5, + 0, + 93, + 92, + 1, + 0, + 0, + 0, + 94, + 97, + 1, + 0, + 0, + 0, + 95, + 93, + 1, + 0, + 0, + 0, + 95, + 96, + 1, + 0, + 0, + 0, + 96, + 98, + 1, + 0, + 0, + 0, + 97, + 95, + 1, + 0, + 0, + 0, + 98, + 102, + 5, + 19, + 0, + 0, + 99, + 101, + 3, + 12, + 6, + 0, + 100, + 99, + 1, + 0, + 0, + 0, + 101, + 104, + 1, + 0, + 0, + 0, + 102, + 100, + 1, + 0, + 0, + 0, + 102, + 103, + 1, + 0, + 0, + 0, + 103, + 105, + 1, + 0, + 0, + 0, + 104, + 102, + 1, + 0, + 0, + 0, + 105, + 106, + 5, + 20, + 0, + 0, + 106, + 9, + 1, + 0, + 0, + 0, + 107, + 108, + 5, + 16, + 0, + 0, + 108, + 109, + 5, + 7, + 0, + 0, + 109, + 110, + 5, + 22, + 0, + 0, + 110, + 111, + 5, + 9, + 0, + 0, + 111, + 11, + 1, + 0, + 0, + 0, + 112, + 114, + 5, + 1, + 0, + 0, + 113, + 112, + 1, + 0, + 0, + 0, + 113, + 114, + 1, + 0, + 0, + 0, + 114, + 115, + 1, + 0, + 0, + 0, + 115, + 116, + 5, + 2, + 0, + 0, + 116, + 119, + 5, + 16, + 0, + 0, + 117, + 118, + 5, + 4, + 0, + 0, + 118, + 120, + 5, + 16, + 0, + 0, + 119, + 117, + 1, + 0, + 0, + 0, + 119, + 120, + 1, + 0, + 0, + 0, + 120, + 124, + 1, + 0, + 0, + 0, + 121, + 123, + 3, + 10, + 5, + 0, + 122, + 121, + 1, + 0, + 0, + 0, + 123, + 126, + 1, + 0, + 0, + 0, + 124, + 122, + 1, + 0, + 0, + 0, + 124, + 125, + 1, + 0, + 0, + 0, + 125, + 127, + 1, + 0, + 0, + 0, + 126, + 124, + 1, + 0, + 0, + 0, + 127, + 132, + 5, + 19, + 0, + 0, + 128, + 131, + 3, + 14, + 7, + 0, + 129, + 131, + 3, + 42, + 21, + 0, + 130, + 128, + 1, + 0, + 0, + 0, + 130, + 129, + 1, + 0, + 0, + 0, + 131, + 134, + 1, + 0, + 0, + 0, + 132, + 130, + 1, + 0, + 0, + 0, + 132, + 133, + 1, + 0, + 0, + 0, + 133, + 135, + 1, + 0, + 0, + 0, + 134, + 132, + 1, + 0, + 0, + 0, + 135, + 136, + 5, + 20, + 0, + 0, + 136, + 13, + 1, + 0, + 0, + 0, + 137, + 138, + 3, + 16, + 8, + 0, + 138, + 142, + 5, + 16, + 0, + 0, + 139, + 141, + 3, + 18, + 9, + 0, + 140, + 139, + 1, + 0, + 0, + 0, + 141, + 144, + 1, + 0, + 0, + 0, + 142, + 140, + 1, + 0, + 0, + 0, + 142, + 143, + 1, + 0, + 0, + 0, + 143, + 146, + 1, + 0, + 0, + 0, + 144, + 142, + 1, + 0, + 0, + 0, + 145, + 147, + 3, + 20, + 10, + 0, + 146, + 145, + 1, + 0, + 0, + 0, + 146, + 147, + 1, + 0, + 0, + 0, + 147, + 149, + 1, + 0, + 0, + 0, + 148, + 150, + 3, + 24, + 12, + 0, + 149, + 148, + 1, + 0, + 0, + 0, + 149, + 150, + 1, + 0, + 0, + 0, + 150, + 154, + 1, + 0, + 0, + 0, + 151, + 153, + 3, + 10, + 5, + 0, + 152, + 151, + 1, + 0, + 0, + 0, + 153, + 156, + 1, + 0, + 0, + 0, + 154, + 152, + 1, + 0, + 0, + 0, + 154, + 155, + 1, + 0, + 0, + 0, + 155, + 158, + 1, + 0, + 0, + 0, + 156, + 154, + 1, + 0, + 0, + 0, + 157, + 159, + 3, + 36, + 18, + 0, + 158, + 157, + 1, + 0, + 0, + 0, + 158, + 159, + 1, + 0, + 0, + 0, + 159, + 161, + 1, + 0, + 0, + 0, + 160, + 162, + 3, + 38, + 19, + 0, + 161, + 160, + 1, + 0, + 0, + 0, + 161, + 162, + 1, + 0, + 0, + 0, + 162, + 15, + 1, + 0, + 0, + 0, + 163, + 164, + 7, + 0, + 0, + 0, + 164, + 17, + 1, + 0, + 0, + 0, + 165, + 166, + 5, + 38, + 0, + 0, + 166, + 167, + 5, + 16, + 0, + 0, + 167, + 19, + 1, + 0, + 0, + 0, + 168, + 169, + 5, + 19, + 0, + 0, + 169, + 174, + 3, + 22, + 11, + 0, + 170, + 171, + 5, + 42, + 0, + 0, + 171, + 173, + 3, + 22, + 11, + 0, + 172, + 170, + 1, + 0, + 0, + 0, + 173, + 176, + 1, + 0, + 0, + 0, + 174, + 172, + 1, + 0, + 0, + 0, + 174, + 175, + 1, + 0, + 0, + 0, + 175, + 177, + 1, + 0, + 0, + 0, + 176, + 174, + 1, + 0, + 0, + 0, + 177, + 178, + 5, + 20, + 0, + 0, + 178, + 21, + 1, + 0, + 0, + 0, + 179, + 180, + 7, + 1, + 0, + 0, + 180, + 23, + 1, + 0, + 0, + 0, + 181, + 182, + 5, + 25, + 0, + 0, + 182, + 183, + 3, + 26, + 13, + 0, + 183, + 184, + 5, + 26, + 0, + 0, + 184, + 25, + 1, + 0, + 0, + 0, + 185, + 190, + 3, + 28, + 14, + 0, + 186, + 187, + 7, + 2, + 0, + 0, + 187, + 189, + 3, + 28, + 14, + 0, + 188, + 186, + 1, + 0, + 0, + 0, + 189, + 192, + 1, + 0, + 0, + 0, + 190, + 188, + 1, + 0, + 0, + 0, + 190, + 191, + 1, + 0, + 0, + 0, + 191, + 27, + 1, + 0, + 0, + 0, + 192, + 190, + 1, + 0, + 0, + 0, + 193, + 198, + 3, + 30, + 15, + 0, + 194, + 195, + 7, + 3, + 0, + 0, + 195, + 197, + 3, + 30, + 15, + 0, + 196, + 194, + 1, + 0, + 0, + 0, + 197, + 200, + 1, + 0, + 0, + 0, + 198, + 196, + 1, + 0, + 0, + 0, + 198, + 199, + 1, + 0, + 0, + 0, + 199, + 29, + 1, + 0, + 0, + 0, + 200, + 198, + 1, + 0, + 0, + 0, + 201, + 204, + 3, + 32, + 16, + 0, + 202, + 203, + 5, + 45, + 0, + 0, + 203, + 205, + 3, + 32, + 16, + 0, + 204, + 202, + 1, + 0, + 0, + 0, + 204, + 205, + 1, + 0, + 0, + 0, + 205, + 31, + 1, + 0, + 0, + 0, + 206, + 213, + 3, + 34, + 17, + 0, + 207, + 208, + 5, + 17, + 0, + 0, + 208, + 209, + 3, + 26, + 13, + 0, + 209, + 210, + 5, + 18, + 0, + 0, + 210, + 213, + 1, + 0, + 0, + 0, + 211, + 213, + 3, + 40, + 20, + 0, + 212, + 206, + 1, + 0, + 0, + 0, + 212, + 207, + 1, + 0, + 0, + 0, + 212, + 211, + 1, + 0, + 0, + 0, + 213, + 33, + 1, + 0, + 0, + 0, + 214, + 227, + 5, + 16, + 0, + 0, + 215, + 224, + 5, + 17, + 0, + 0, + 216, + 221, + 3, + 40, + 20, + 0, + 217, + 218, + 5, + 42, + 0, + 0, + 218, + 220, + 3, + 40, + 20, + 0, + 219, + 217, + 1, + 0, + 0, + 0, + 220, + 223, + 1, + 0, + 0, + 0, + 221, + 219, + 1, + 0, + 0, + 0, + 221, + 222, + 1, + 0, + 0, + 0, + 222, + 225, + 1, + 0, + 0, + 0, + 223, + 221, + 1, + 0, + 0, + 0, + 224, + 216, + 1, + 0, + 0, + 0, + 224, + 225, + 1, + 0, + 0, + 0, + 225, + 226, + 1, + 0, + 0, + 0, + 226, + 228, + 5, + 18, + 0, + 0, + 227, + 215, + 1, + 0, + 0, + 0, + 227, + 228, + 1, + 0, + 0, + 0, + 228, + 35, + 1, + 0, + 0, + 0, + 229, + 230, + 5, + 39, + 0, + 0, + 230, + 235, + 3, + 44, + 22, + 0, + 231, + 232, + 5, + 42, + 0, + 0, + 232, + 234, + 3, + 44, + 22, + 0, + 233, + 231, + 1, + 0, + 0, + 0, + 234, + 237, + 1, + 0, + 0, + 0, + 235, + 233, + 1, + 0, + 0, + 0, + 235, + 236, + 1, + 0, + 0, + 0, + 236, + 37, + 1, + 0, + 0, + 0, + 237, + 235, + 1, + 0, + 0, + 0, + 238, + 239, + 7, + 4, + 0, + 0, + 239, + 244, + 3, + 44, + 22, + 0, + 240, + 241, + 5, + 42, + 0, + 0, + 241, + 243, + 3, + 44, + 22, + 0, + 242, + 240, + 1, + 0, + 0, + 0, + 243, + 246, + 1, + 0, + 0, + 0, + 244, + 242, + 1, + 0, + 0, + 0, + 244, + 245, + 1, + 0, + 0, + 0, + 245, + 39, + 1, + 0, + 0, + 0, + 246, + 244, + 1, + 0, + 0, + 0, + 247, + 248, + 7, + 5, + 0, + 0, + 248, + 41, + 1, + 0, + 0, + 0, + 249, + 250, + 5, + 8, + 0, + 0, + 250, + 251, + 5, + 16, + 0, + 0, + 251, + 252, + 5, + 29, + 0, + 0, + 252, + 253, + 3, + 44, + 22, + 0, + 253, + 43, + 1, + 0, + 0, + 0, + 254, + 260, + 3, + 46, + 23, + 0, + 255, + 256, + 3, + 54, + 27, + 0, + 256, + 257, + 3, + 46, + 23, + 0, + 257, + 259, + 1, + 0, + 0, + 0, + 258, + 255, + 1, + 0, + 0, + 0, + 259, + 262, + 1, + 0, + 0, + 0, + 260, + 258, + 1, + 0, + 0, + 0, + 260, + 261, + 1, + 0, + 0, + 0, + 261, + 45, + 1, + 0, + 0, + 0, + 262, + 260, + 1, + 0, + 0, + 0, + 263, + 268, + 3, + 48, + 24, + 0, + 264, + 265, + 5, + 34, + 0, + 0, + 265, + 267, + 3, + 48, + 24, + 0, + 266, + 264, + 1, + 0, + 0, + 0, + 267, + 270, + 1, + 0, + 0, + 0, + 268, + 266, + 1, + 0, + 0, + 0, + 268, + 269, + 1, + 0, + 0, + 0, + 269, + 47, + 1, + 0, + 0, + 0, + 270, + 268, + 1, + 0, + 0, + 0, + 271, + 272, + 5, + 17, + 0, + 0, + 272, + 273, + 3, + 44, + 22, + 0, + 273, + 274, + 5, + 18, + 0, + 0, + 274, + 281, + 1, + 0, + 0, + 0, + 275, + 276, + 3, + 50, + 25, + 0, + 276, + 277, + 5, + 17, + 0, + 0, + 277, + 278, + 5, + 18, + 0, + 0, + 278, + 281, + 1, + 0, + 0, + 0, + 279, + 281, + 5, + 16, + 0, + 0, + 280, + 271, + 1, + 0, + 0, + 0, + 280, + 275, + 1, + 0, + 0, + 0, + 280, + 279, + 1, + 0, + 0, + 0, + 281, + 283, + 1, + 0, + 0, + 0, + 282, + 284, + 5, + 27, + 0, + 0, + 283, + 282, + 1, + 0, + 0, + 0, + 283, + 284, + 1, + 0, + 0, + 0, + 284, + 288, + 1, + 0, + 0, + 0, + 285, + 287, + 3, + 52, + 26, + 0, + 286, + 285, + 1, + 0, + 0, + 0, + 287, + 290, + 1, + 0, + 0, + 0, + 288, + 286, + 1, + 0, + 0, + 0, + 288, + 289, + 1, + 0, + 0, + 0, + 289, + 49, + 1, + 0, + 0, + 0, + 290, + 288, + 1, + 0, + 0, + 0, + 291, + 292, + 5, + 16, + 0, + 0, + 292, + 51, + 1, + 0, + 0, + 0, + 293, + 294, + 5, + 25, + 0, + 0, + 294, + 295, + 5, + 16, + 0, + 0, + 295, + 296, + 5, + 26, + 0, + 0, + 296, + 53, + 1, + 0, + 0, + 0, + 297, + 298, + 7, + 6, + 0, + 0, + 298, + 55, + 1, + 0, + 0, + 0, + 299, + 300, + 5, + 3, + 0, + 0, + 300, + 304, + 5, + 19, + 0, + 0, + 301, + 303, + 3, + 58, + 29, + 0, + 302, + 301, + 1, + 0, + 0, + 0, + 303, + 306, + 1, + 0, + 0, + 0, + 304, + 302, + 1, + 0, + 0, + 0, + 304, + 305, + 1, + 0, + 0, + 0, + 305, + 307, + 1, + 0, + 0, + 0, + 306, + 304, + 1, + 0, + 0, + 0, + 307, + 308, + 5, + 20, + 0, + 0, + 308, + 57, + 1, + 0, + 0, + 0, + 309, + 310, + 5, + 16, + 0, + 0, + 310, + 311, + 3, + 60, + 30, + 0, + 311, + 312, + 3, + 62, + 31, + 0, + 312, + 313, + 5, + 23, + 0, + 0, + 313, + 314, + 3, + 66, + 33, + 0, + 314, + 315, + 5, + 24, + 0, + 0, + 315, + 316, + 3, + 62, + 31, + 0, + 316, + 317, + 3, + 60, + 30, + 0, + 317, + 321, + 5, + 16, + 0, + 0, + 318, + 320, + 3, + 10, + 5, + 0, + 319, + 318, + 1, + 0, + 0, + 0, + 320, + 323, + 1, + 0, + 0, + 0, + 321, + 319, + 1, + 0, + 0, + 0, + 321, + 322, + 1, + 0, + 0, + 0, + 322, + 59, + 1, + 0, + 0, + 0, + 323, + 321, + 1, + 0, + 0, + 0, + 324, + 325, + 5, + 25, + 0, + 0, + 325, + 326, + 5, + 16, + 0, + 0, + 326, + 327, + 5, + 26, + 0, + 0, + 327, + 61, + 1, + 0, + 0, + 0, + 328, + 331, + 3, + 64, + 32, + 0, + 329, + 330, + 5, + 33, + 0, + 0, + 330, + 332, + 3, + 64, + 32, + 0, + 331, + 329, + 1, + 0, + 0, + 0, + 331, + 332, + 1, + 0, + 0, + 0, + 332, + 63, + 1, + 0, + 0, + 0, + 333, + 334, + 7, + 7, + 0, + 0, + 334, + 65, + 1, + 0, + 0, + 0, + 335, + 336, + 5, + 16, + 0, + 0, + 336, + 67, + 1, + 0, + 0, + 0, + 34, + 71, + 74, + 80, + 95, + 102, + 113, + 119, + 124, + 130, + 132, + 142, + 146, + 149, + 154, + 158, + 161, + 174, + 190, + 198, + 204, + 212, + 221, + 224, + 227, + 235, + 244, + 260, + 268, + 280, + 283, + 288, + 304, + 321, + 331, ] -class malParser ( Parser ): - grammarFileName = "mal.g4" +class malParser(Parser): + grammarFileName = 'mal.g4' atn = ATNDeserializer().deserialize(serializedATN()) - decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + decisionsToDFA = [DFA(ds, i) for i, ds in enumerate(atn.decisionToState)] sharedContextCache = PredictionContextCache() - literalNames = [ "", "'abstract'", "'asset'", "'associations'", - "'extends'", "'include'", "'category'", "'info'", "'let'", - "", "", "", "'E'", "'C'", - "'I'", "'A'", "", "'('", "')'", "'{'", "'}'", - "'#'", "':'", "'<--'", "'-->'", "'['", "']'", "'*'", - "'1'", "'='", "'-'", "'/\\'", "'\\/'", "'..'", "'.'", - "'&'", "'|'", "'!E'", "'@'", "'<-'", "'+>'", "'->'", - "','", "'+'", "'/'", "'^'" ] - - symbolicNames = [ "", "ABSTRACT", "ASSET", "ASSOCIATIONS", - "EXTENDS", "INCLUDE", "CATEGORY", "INFO", "LET", "STRING", - "INT", "FLOAT", "EXISTS", "C", "I", "A", "ID", "LPAREN", - "RPAREN", "LCURLY", "RCURLY", "HASH", "COLON", "LARROW", - "RARROW", "LSQUARE", "RSQUARE", "STAR", "ONE", "ASSIGN", - "MINUS", "INTERSECT", "UNION", "RANGE", "DOT", "AND", - "OR", "NOTEXISTS", "AT", "REQUIRES", "INHERITS", "LEADSTO", - "COMMA", "PLUS", "DIVIDE", "POWER", "INLINE_COMMENT", - "MULTILINE_COMMENT", "WS" ] + literalNames = [ + '', + "'abstract'", + "'asset'", + "'associations'", + "'extends'", + "'include'", + "'category'", + "'info'", + "'let'", + '', + '', + '', + "'E'", + "'C'", + "'I'", + "'A'", + '', + "'('", + "')'", + "'{'", + "'}'", + "'#'", + "':'", + "'<--'", + "'-->'", + "'['", + "']'", + "'*'", + "'1'", + "'='", + "'-'", + "'/\\'", + "'\\/'", + "'..'", + "'.'", + "'&'", + "'|'", + "'!E'", + "'@'", + "'<-'", + "'+>'", + "'->'", + "','", + "'+'", + "'/'", + "'^'", + ] + + symbolicNames = [ + '', + 'ABSTRACT', + 'ASSET', + 'ASSOCIATIONS', + 'EXTENDS', + 'INCLUDE', + 'CATEGORY', + 'INFO', + 'LET', + 'STRING', + 'INT', + 'FLOAT', + 'EXISTS', + 'C', + 'I', + 'A', + 'ID', + 'LPAREN', + 'RPAREN', + 'LCURLY', + 'RCURLY', + 'HASH', + 'COLON', + 'LARROW', + 'RARROW', + 'LSQUARE', + 'RSQUARE', + 'STAR', + 'ONE', + 'ASSIGN', + 'MINUS', + 'INTERSECT', + 'UNION', + 'RANGE', + 'DOT', + 'AND', + 'OR', + 'NOTEXISTS', + 'AT', + 'REQUIRES', + 'INHERITS', + 'LEADSTO', + 'COMMA', + 'PLUS', + 'DIVIDE', + 'POWER', + 'INLINE_COMMENT', + 'MULTILINE_COMMENT', + 'WS', + ] RULE_mal = 0 RULE_declaration = 1 @@ -194,85 +3055,114 @@ class malParser ( Parser ): RULE_multatom = 32 RULE_linkname = 33 - ruleNames = [ "mal", "declaration", "include", "define", "category", - "meta", "asset", "step", "steptype", "tag", "cias", "cia", - "ttc", "ttcexpr", "ttcterm", "ttcfact", "ttcatom", "ttcdist", - "precondition", "reaches", "number", "variable", "expr", - "parts", "part", "varsubst", "type", "setop", "associations", - "association", "field", "mult", "multatom", "linkname" ] + ruleNames = [ + 'mal', + 'declaration', + 'include', + 'define', + 'category', + 'meta', + 'asset', + 'step', + 'steptype', + 'tag', + 'cias', + 'cia', + 'ttc', + 'ttcexpr', + 'ttcterm', + 'ttcfact', + 'ttcatom', + 'ttcdist', + 'precondition', + 'reaches', + 'number', + 'variable', + 'expr', + 'parts', + 'part', + 'varsubst', + 'type', + 'setop', + 'associations', + 'association', + 'field', + 'mult', + 'multatom', + 'linkname', + ] EOF = Token.EOF - ABSTRACT=1 - ASSET=2 - ASSOCIATIONS=3 - EXTENDS=4 - INCLUDE=5 - CATEGORY=6 - INFO=7 - LET=8 - STRING=9 - INT=10 - FLOAT=11 - EXISTS=12 - C=13 - I=14 - A=15 - ID=16 - LPAREN=17 - RPAREN=18 - LCURLY=19 - RCURLY=20 - HASH=21 - COLON=22 - LARROW=23 - RARROW=24 - LSQUARE=25 - RSQUARE=26 - STAR=27 - ONE=28 - ASSIGN=29 - MINUS=30 - INTERSECT=31 - UNION=32 - RANGE=33 - DOT=34 - AND=35 - OR=36 - NOTEXISTS=37 - AT=38 - REQUIRES=39 - INHERITS=40 - LEADSTO=41 - COMMA=42 - PLUS=43 - DIVIDE=44 - POWER=45 - INLINE_COMMENT=46 - MULTILINE_COMMENT=47 - WS=48 - - def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + ABSTRACT = 1 + ASSET = 2 + ASSOCIATIONS = 3 + EXTENDS = 4 + INCLUDE = 5 + CATEGORY = 6 + INFO = 7 + LET = 8 + STRING = 9 + INT = 10 + FLOAT = 11 + EXISTS = 12 + C = 13 + I = 14 + A = 15 + ID = 16 + LPAREN = 17 + RPAREN = 18 + LCURLY = 19 + RCURLY = 20 + HASH = 21 + COLON = 22 + LARROW = 23 + RARROW = 24 + LSQUARE = 25 + RSQUARE = 26 + STAR = 27 + ONE = 28 + ASSIGN = 29 + MINUS = 30 + INTERSECT = 31 + UNION = 32 + RANGE = 33 + DOT = 34 + AND = 35 + OR = 36 + NOTEXISTS = 37 + AT = 38 + REQUIRES = 39 + INHERITS = 40 + LEADSTO = 41 + COMMA = 42 + PLUS = 43 + DIVIDE = 44 + POWER = 45 + INLINE_COMMENT = 46 + MULTILINE_COMMENT = 47 + WS = 48 + + def __init__(self, input: TokenStream, output: TextIO = sys.stdout) -> None: super().__init__(input, output) - self.checkVersion("4.13.1") - self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self.checkVersion('4.13.1') + self._interp = ParserATNSimulator( + self, self.atn, self.decisionsToDFA, self.sharedContextCache + ) self._predicates = None - - - class MalContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def declaration(self, i:int=None): + def declaration(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.DeclarationContext) - else: - return self.getTypedRuleContext(malParser.DeclarationContext,i) - + return self.getTypedRuleContext(malParser.DeclarationContext, i) def EOF(self): return self.getToken(malParser.EOF, 0) @@ -280,52 +3170,45 @@ def EOF(self): def getRuleIndex(self): return malParser.RULE_mal - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterMal" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterMal'): listener.enterMal(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitMal" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitMal'): listener.exitMal(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMal" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitMal'): return visitor.visitMal(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def mal(self): - localctx = malParser.MalContext(self, self._ctx, self.state) self.enterRule(localctx, 0, self.RULE_mal) - self._la = 0 # Token type + self._la = 0 # Token type try: self.state = 74 self._errHandler.sync(self) token = self._input.LA(1) - if token in [3, 5, 6, 21]: + if token in {3, 5, 6, 21}: self.enterOuterAlt(localctx, 1) - self.state = 69 + self.state = 69 self._errHandler.sync(self) _la = self._input.LA(1) while True: self.state = 68 self.declaration() - self.state = 71 + self.state = 71 self._errHandler.sync(self) _la = self._input.LA(1) - if not ((((_la) & ~0x3f) == 0 and ((1 << _la) & 2097256) != 0)): + if not (((_la) & ~0x3F) == 0 and ((1 << _la) & 2097256) != 0): break - pass - elif token in [-1]: + elif token == -1: self.enterOuterAlt(localctx, 2) self.state = 73 self.match(malParser.EOF) - pass else: raise NoViableAltException(self) @@ -337,78 +3220,66 @@ def mal(self): self.exitRule() return localctx - class DeclarationContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser def include(self): - return self.getTypedRuleContext(malParser.IncludeContext,0) - + return self.getTypedRuleContext(malParser.IncludeContext, 0) def define(self): - return self.getTypedRuleContext(malParser.DefineContext,0) - + return self.getTypedRuleContext(malParser.DefineContext, 0) def category(self): - return self.getTypedRuleContext(malParser.CategoryContext,0) - + return self.getTypedRuleContext(malParser.CategoryContext, 0) def associations(self): - return self.getTypedRuleContext(malParser.AssociationsContext,0) - + return self.getTypedRuleContext(malParser.AssociationsContext, 0) def getRuleIndex(self): return malParser.RULE_declaration - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterDeclaration" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterDeclaration'): listener.enterDeclaration(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitDeclaration" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitDeclaration'): listener.exitDeclaration(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitDeclaration" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitDeclaration'): return visitor.visitDeclaration(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def declaration(self): - localctx = malParser.DeclarationContext(self, self._ctx, self.state) self.enterRule(localctx, 2, self.RULE_declaration) try: self.state = 80 self._errHandler.sync(self) token = self._input.LA(1) - if token in [5]: + if token == 5: self.enterOuterAlt(localctx, 1) self.state = 76 self.include() - pass - elif token in [21]: + elif token == 21: self.enterOuterAlt(localctx, 2) self.state = 77 self.define() - pass - elif token in [6]: + elif token == 6: self.enterOuterAlt(localctx, 3) self.state = 78 self.category() - pass - elif token in [3]: + elif token == 3: self.enterOuterAlt(localctx, 4) self.state = 79 self.associations() - pass else: raise NoViableAltException(self) @@ -420,11 +3291,12 @@ def declaration(self): self.exitRule() return localctx - class IncludeContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -437,25 +3309,20 @@ def STRING(self): def getRuleIndex(self): return malParser.RULE_include - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterInclude" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterInclude'): listener.enterInclude(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitInclude" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitInclude'): listener.exitInclude(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitInclude" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitInclude'): return visitor.visitInclude(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def include(self): - localctx = malParser.IncludeContext(self, self._ctx, self.state) self.enterRule(localctx, 4, self.RULE_include) try: @@ -472,11 +3339,12 @@ def include(self): self.exitRule() return localctx - class DefineContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -495,25 +3363,20 @@ def STRING(self): def getRuleIndex(self): return malParser.RULE_define - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterDefine" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterDefine'): listener.enterDefine(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitDefine" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitDefine'): listener.exitDefine(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitDefine" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitDefine'): return visitor.visitDefine(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def define(self): - localctx = malParser.DefineContext(self, self._ctx, self.state) self.enterRule(localctx, 6, self.RULE_define) try: @@ -534,11 +3397,12 @@ def define(self): self.exitRule() return localctx - class CategoryContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -554,45 +3418,36 @@ def LCURLY(self): def RCURLY(self): return self.getToken(malParser.RCURLY, 0) - def meta(self, i:int=None): + def meta(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.MetaContext) - else: - return self.getTypedRuleContext(malParser.MetaContext,i) - + return self.getTypedRuleContext(malParser.MetaContext, i) - def asset(self, i:int=None): + def asset(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.AssetContext) - else: - return self.getTypedRuleContext(malParser.AssetContext,i) - + return self.getTypedRuleContext(malParser.AssetContext, i) def getRuleIndex(self): return malParser.RULE_category - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterCategory" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterCategory'): listener.enterCategory(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitCategory" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitCategory'): listener.exitCategory(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCategory" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitCategory'): return visitor.visitCategory(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def category(self): - localctx = malParser.CategoryContext(self, self._ctx, self.state) self.enterRule(localctx, 8, self.RULE_category) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 90 @@ -602,7 +3457,7 @@ def category(self): self.state = 95 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==16: + while _la == 16: self.state = 92 self.meta() self.state = 97 @@ -614,7 +3469,7 @@ def category(self): self.state = 102 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==1 or _la==2: + while _la in {1, 2}: self.state = 99 self.asset() self.state = 104 @@ -631,11 +3486,12 @@ def category(self): self.exitRule() return localctx - class MetaContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -654,25 +3510,20 @@ def STRING(self): def getRuleIndex(self): return malParser.RULE_meta - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterMeta" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterMeta'): listener.enterMeta(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitMeta" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitMeta'): listener.exitMeta(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMeta" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitMeta'): return visitor.visitMeta(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def meta(self): - localctx = malParser.MetaContext(self, self._ctx, self.state) self.enterRule(localctx, 10, self.RULE_meta) try: @@ -693,22 +3544,22 @@ def meta(self): self.exitRule() return localctx - class AssetContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser def ASSET(self): return self.getToken(malParser.ASSET, 0) - def ID(self, i:int=None): + def ID(self, i: int | None = None): if i is None: return self.getTokens(malParser.ID) - else: - return self.getToken(malParser.ID, i) + return self.getToken(malParser.ID, i) def LCURLY(self): return self.getToken(malParser.LCURLY, 0) @@ -722,62 +3573,50 @@ def ABSTRACT(self): def EXTENDS(self): return self.getToken(malParser.EXTENDS, 0) - def meta(self, i:int=None): + def meta(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.MetaContext) - else: - return self.getTypedRuleContext(malParser.MetaContext,i) - + return self.getTypedRuleContext(malParser.MetaContext, i) - def step(self, i:int=None): + def step(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.StepContext) - else: - return self.getTypedRuleContext(malParser.StepContext,i) + return self.getTypedRuleContext(malParser.StepContext, i) - - def variable(self, i:int=None): + def variable(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.VariableContext) - else: - return self.getTypedRuleContext(malParser.VariableContext,i) - + return self.getTypedRuleContext(malParser.VariableContext, i) def getRuleIndex(self): return malParser.RULE_asset - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterAsset" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterAsset'): listener.enterAsset(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitAsset" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitAsset'): listener.exitAsset(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAsset" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitAsset'): return visitor.visitAsset(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def asset(self): - localctx = malParser.AssetContext(self, self._ctx, self.state) self.enterRule(localctx, 12, self.RULE_asset) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 113 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==1: + if _la == 1: self.state = 112 self.match(malParser.ABSTRACT) - self.state = 115 self.match(malParser.ASSET) self.state = 116 @@ -785,17 +3624,16 @@ def asset(self): self.state = 119 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==4: + if _la == 4: self.state = 117 self.match(malParser.EXTENDS) self.state = 118 self.match(malParser.ID) - self.state = 124 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==16: + while _la == 16: self.state = 121 self.meta() self.state = 126 @@ -807,18 +3645,16 @@ def asset(self): self.state = 132 self._errHandler.sync(self) _la = self._input.LA(1) - while (((_la) & ~0x3f) == 0 and ((1 << _la) & 240520270080) != 0): + while ((_la) & ~0x3F) == 0 and ((1 << _la) & 240520270080) != 0: self.state = 130 self._errHandler.sync(self) token = self._input.LA(1) - if token in [12, 21, 35, 36, 37]: + if token in {12, 21, 35, 36, 37}: self.state = 128 self.step() - pass - elif token in [8]: + elif token == 8: self.state = 129 self.variable() - pass else: raise NoViableAltException(self) @@ -836,76 +3672,63 @@ def asset(self): self.exitRule() return localctx - class StepContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser def steptype(self): - return self.getTypedRuleContext(malParser.SteptypeContext,0) - + return self.getTypedRuleContext(malParser.SteptypeContext, 0) def ID(self): return self.getToken(malParser.ID, 0) - def tag(self, i:int=None): + def tag(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.TagContext) - else: - return self.getTypedRuleContext(malParser.TagContext,i) - + return self.getTypedRuleContext(malParser.TagContext, i) def cias(self): - return self.getTypedRuleContext(malParser.CiasContext,0) - + return self.getTypedRuleContext(malParser.CiasContext, 0) def ttc(self): - return self.getTypedRuleContext(malParser.TtcContext,0) - + return self.getTypedRuleContext(malParser.TtcContext, 0) - def meta(self, i:int=None): + def meta(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.MetaContext) - else: - return self.getTypedRuleContext(malParser.MetaContext,i) - + return self.getTypedRuleContext(malParser.MetaContext, i) def precondition(self): - return self.getTypedRuleContext(malParser.PreconditionContext,0) - + return self.getTypedRuleContext(malParser.PreconditionContext, 0) def reaches(self): - return self.getTypedRuleContext(malParser.ReachesContext,0) - + return self.getTypedRuleContext(malParser.ReachesContext, 0) def getRuleIndex(self): return malParser.RULE_step - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterStep" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterStep'): listener.enterStep(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitStep" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitStep'): listener.exitStep(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitStep" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitStep'): return visitor.visitStep(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def step(self): - localctx = malParser.StepContext(self, self._ctx, self.state) self.enterRule(localctx, 14, self.RULE_step) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 137 @@ -915,7 +3738,7 @@ def step(self): self.state = 142 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==38: + while _la == 38: self.state = 139 self.tag() self.state = 144 @@ -925,23 +3748,21 @@ def step(self): self.state = 146 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==19: + if _la == 19: self.state = 145 self.cias() - self.state = 149 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==25: + if _la == 25: self.state = 148 self.ttc() - self.state = 154 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==16: + while _la == 16: self.state = 151 self.meta() self.state = 156 @@ -951,19 +3772,17 @@ def step(self): self.state = 158 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==39: + if _la == 39: self.state = 157 self.precondition() - self.state = 161 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==40 or _la==41: + if _la in {40, 41}: self.state = 160 self.reaches() - except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -972,11 +3791,12 @@ def step(self): self.exitRule() return localctx - class SteptypeContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -998,33 +3818,28 @@ def NOTEXISTS(self): def getRuleIndex(self): return malParser.RULE_steptype - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterSteptype" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterSteptype'): listener.enterSteptype(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitSteptype" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitSteptype'): listener.exitSteptype(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitSteptype" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitSteptype'): return visitor.visitSteptype(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def steptype(self): - localctx = malParser.SteptypeContext(self, self._ctx, self.state) self.enterRule(localctx, 16, self.RULE_steptype) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 163 _la = self._input.LA(1) - if not((((_la) & ~0x3f) == 0 and ((1 << _la) & 240520269824) != 0)): + if not (((_la) & ~0x3F) == 0 and ((1 << _la) & 240520269824) != 0): self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -1037,11 +3852,12 @@ def steptype(self): self.exitRule() return localctx - class TagContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -1054,25 +3870,20 @@ def ID(self): def getRuleIndex(self): return malParser.RULE_tag - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterTag" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterTag'): listener.enterTag(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitTag" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitTag'): listener.exitTag(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTag" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitTag'): return visitor.visitTag(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def tag(self): - localctx = malParser.TagContext(self, self._ctx, self.state) self.enterRule(localctx, 18, self.RULE_tag) try: @@ -1089,58 +3900,51 @@ def tag(self): self.exitRule() return localctx - class CiasContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser def LCURLY(self): return self.getToken(malParser.LCURLY, 0) - def cia(self, i:int=None): + def cia(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.CiaContext) - else: - return self.getTypedRuleContext(malParser.CiaContext,i) - + return self.getTypedRuleContext(malParser.CiaContext, i) def RCURLY(self): return self.getToken(malParser.RCURLY, 0) - def COMMA(self, i:int=None): + def COMMA(self, i: int | None = None): if i is None: return self.getTokens(malParser.COMMA) - else: - return self.getToken(malParser.COMMA, i) + return self.getToken(malParser.COMMA, i) def getRuleIndex(self): return malParser.RULE_cias - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterCias" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterCias'): listener.enterCias(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitCias" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitCias'): listener.exitCias(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCias" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitCias'): return visitor.visitCias(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def cias(self): - localctx = malParser.CiasContext(self, self._ctx, self.state) self.enterRule(localctx, 20, self.RULE_cias) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 168 @@ -1150,7 +3954,7 @@ def cias(self): self.state = 174 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==42: + while _la == 42: self.state = 170 self.match(malParser.COMMA) self.state = 171 @@ -1169,11 +3973,12 @@ def cias(self): self.exitRule() return localctx - class CiaContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -1189,33 +3994,28 @@ def A(self): def getRuleIndex(self): return malParser.RULE_cia - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterCia" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterCia'): listener.enterCia(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitCia" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitCia'): listener.exitCia(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitCia" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitCia'): return visitor.visitCia(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def cia(self): - localctx = malParser.CiaContext(self, self._ctx, self.state) self.enterRule(localctx, 22, self.RULE_cia) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 179 _la = self._input.LA(1) - if not((((_la) & ~0x3f) == 0 and ((1 << _la) & 57344) != 0)): + if not (((_la) & ~0x3F) == 0 and ((1 << _la) & 57344) != 0): self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -1228,11 +4028,12 @@ def cia(self): self.exitRule() return localctx - class TtcContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -1240,8 +4041,7 @@ def LSQUARE(self): return self.getToken(malParser.LSQUARE, 0) def ttcexpr(self): - return self.getTypedRuleContext(malParser.TtcexprContext,0) - + return self.getTypedRuleContext(malParser.TtcexprContext, 0) def RSQUARE(self): return self.getToken(malParser.RSQUARE, 0) @@ -1249,25 +4049,20 @@ def RSQUARE(self): def getRuleIndex(self): return malParser.RULE_ttc - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterTtc" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterTtc'): listener.enterTtc(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitTtc" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitTtc'): listener.exitTtc(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTtc" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitTtc'): return visitor.visitTtc(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def ttc(self): - localctx = malParser.TtcContext(self, self._ctx, self.state) self.enterRule(localctx, 24, self.RULE_ttc) try: @@ -1286,58 +4081,50 @@ def ttc(self): self.exitRule() return localctx - class TtcexprContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def ttcterm(self, i:int=None): + def ttcterm(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.TtctermContext) - else: - return self.getTypedRuleContext(malParser.TtctermContext,i) + return self.getTypedRuleContext(malParser.TtctermContext, i) - - def PLUS(self, i:int=None): + def PLUS(self, i: int | None = None): if i is None: return self.getTokens(malParser.PLUS) - else: - return self.getToken(malParser.PLUS, i) + return self.getToken(malParser.PLUS, i) - def MINUS(self, i:int=None): + def MINUS(self, i: int | None = None): if i is None: return self.getTokens(malParser.MINUS) - else: - return self.getToken(malParser.MINUS, i) + return self.getToken(malParser.MINUS, i) def getRuleIndex(self): return malParser.RULE_ttcexpr - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterTtcexpr" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterTtcexpr'): listener.enterTtcexpr(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitTtcexpr" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitTtcexpr'): listener.exitTtcexpr(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTtcexpr" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitTtcexpr'): return visitor.visitTtcexpr(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def ttcexpr(self): - localctx = malParser.TtcexprContext(self, self._ctx, self.state) self.enterRule(localctx, 26, self.RULE_ttcexpr) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 185 @@ -1345,10 +4132,10 @@ def ttcexpr(self): self.state = 190 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==30 or _la==43: + while _la in {30, 43}: self.state = 186 _la = self._input.LA(1) - if not(_la==30 or _la==43): + if _la not in {30, 43}: self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -1367,58 +4154,50 @@ def ttcexpr(self): self.exitRule() return localctx - class TtctermContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def ttcfact(self, i:int=None): + def ttcfact(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.TtcfactContext) - else: - return self.getTypedRuleContext(malParser.TtcfactContext,i) - + return self.getTypedRuleContext(malParser.TtcfactContext, i) - def STAR(self, i:int=None): + def STAR(self, i: int | None = None): if i is None: return self.getTokens(malParser.STAR) - else: - return self.getToken(malParser.STAR, i) + return self.getToken(malParser.STAR, i) - def DIVIDE(self, i:int=None): + def DIVIDE(self, i: int | None = None): if i is None: return self.getTokens(malParser.DIVIDE) - else: - return self.getToken(malParser.DIVIDE, i) + return self.getToken(malParser.DIVIDE, i) def getRuleIndex(self): return malParser.RULE_ttcterm - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterTtcterm" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterTtcterm'): listener.enterTtcterm(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitTtcterm" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitTtcterm'): listener.exitTtcterm(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTtcterm" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitTtcterm'): return visitor.visitTtcterm(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def ttcterm(self): - localctx = malParser.TtctermContext(self, self._ctx, self.state) self.enterRule(localctx, 28, self.RULE_ttcterm) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 193 @@ -1426,10 +4205,10 @@ def ttcterm(self): self.state = 198 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==27 or _la==44: + while _la in {27, 44}: self.state = 194 _la = self._input.LA(1) - if not(_la==27 or _la==44): + if _la not in {27, 44}: self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -1448,20 +4227,19 @@ def ttcterm(self): self.exitRule() return localctx - class TtcfactContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def ttcatom(self, i:int=None): + def ttcatom(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.TtcatomContext) - else: - return self.getTypedRuleContext(malParser.TtcatomContext,i) - + return self.getTypedRuleContext(malParser.TtcatomContext, i) def POWER(self): return self.getToken(malParser.POWER, 0) @@ -1469,28 +4247,23 @@ def POWER(self): def getRuleIndex(self): return malParser.RULE_ttcfact - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterTtcfact" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterTtcfact'): listener.enterTtcfact(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitTtcfact" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitTtcfact'): listener.exitTtcfact(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTtcfact" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitTtcfact'): return visitor.visitTtcfact(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def ttcfact(self): - localctx = malParser.TtcfactContext(self, self._ctx, self.state) self.enterRule(localctx, 30, self.RULE_ttcfact) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 201 @@ -1498,13 +4271,12 @@ def ttcfact(self): self.state = 204 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==45: + if _la == 45: self.state = 202 self.match(malParser.POWER) self.state = 203 self.ttcatom() - except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -1513,66 +4285,58 @@ def ttcfact(self): self.exitRule() return localctx - class TtcatomContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser def ttcdist(self): - return self.getTypedRuleContext(malParser.TtcdistContext,0) - + return self.getTypedRuleContext(malParser.TtcdistContext, 0) def LPAREN(self): return self.getToken(malParser.LPAREN, 0) def ttcexpr(self): - return self.getTypedRuleContext(malParser.TtcexprContext,0) - + return self.getTypedRuleContext(malParser.TtcexprContext, 0) def RPAREN(self): return self.getToken(malParser.RPAREN, 0) def number(self): - return self.getTypedRuleContext(malParser.NumberContext,0) - + return self.getTypedRuleContext(malParser.NumberContext, 0) def getRuleIndex(self): return malParser.RULE_ttcatom - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterTtcatom" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterTtcatom'): listener.enterTtcatom(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitTtcatom" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitTtcatom'): listener.exitTtcatom(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTtcatom" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitTtcatom'): return visitor.visitTtcatom(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def ttcatom(self): - localctx = malParser.TtcatomContext(self, self._ctx, self.state) self.enterRule(localctx, 32, self.RULE_ttcatom) try: self.state = 212 self._errHandler.sync(self) token = self._input.LA(1) - if token in [16]: + if token == 16: self.enterOuterAlt(localctx, 1) self.state = 206 self.ttcdist() - pass - elif token in [17]: + elif token == 17: self.enterOuterAlt(localctx, 2) self.state = 207 self.match(malParser.LPAREN) @@ -1580,12 +4344,10 @@ def ttcatom(self): self.ttcexpr() self.state = 209 self.match(malParser.RPAREN) - pass - elif token in [10, 11]: + elif token in {10, 11}: self.enterOuterAlt(localctx, 3) self.state = 211 self.number() - pass else: raise NoViableAltException(self) @@ -1597,11 +4359,12 @@ def ttcatom(self): self.exitRule() return localctx - class TtcdistContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -1614,44 +4377,36 @@ def LPAREN(self): def RPAREN(self): return self.getToken(malParser.RPAREN, 0) - def number(self, i:int=None): + def number(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.NumberContext) - else: - return self.getTypedRuleContext(malParser.NumberContext,i) + return self.getTypedRuleContext(malParser.NumberContext, i) - - def COMMA(self, i:int=None): + def COMMA(self, i: int | None = None): if i is None: return self.getTokens(malParser.COMMA) - else: - return self.getToken(malParser.COMMA, i) + return self.getToken(malParser.COMMA, i) def getRuleIndex(self): return malParser.RULE_ttcdist - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterTtcdist" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterTtcdist'): listener.enterTtcdist(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitTtcdist" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitTtcdist'): listener.exitTtcdist(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitTtcdist" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitTtcdist'): return visitor.visitTtcdist(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def ttcdist(self): - localctx = malParser.TtcdistContext(self, self._ctx, self.state) self.enterRule(localctx, 34, self.RULE_ttcdist) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 214 @@ -1659,19 +4414,19 @@ def ttcdist(self): self.state = 227 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==17: + if _la == 17: self.state = 215 self.match(malParser.LPAREN) self.state = 224 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==10 or _la==11: + if _la in {10, 11}: self.state = 216 self.number() self.state = 221 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==42: + while _la == 42: self.state = 217 self.match(malParser.COMMA) self.state = 218 @@ -1680,12 +4435,9 @@ def ttcdist(self): self._errHandler.sync(self) _la = self._input.LA(1) - - self.state = 226 self.match(malParser.RPAREN) - except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -1694,55 +4446,48 @@ def ttcdist(self): self.exitRule() return localctx - class PreconditionContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser def REQUIRES(self): return self.getToken(malParser.REQUIRES, 0) - def expr(self, i:int=None): + def expr(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.ExprContext) - else: - return self.getTypedRuleContext(malParser.ExprContext,i) - + return self.getTypedRuleContext(malParser.ExprContext, i) - def COMMA(self, i:int=None): + def COMMA(self, i: int | None = None): if i is None: return self.getTokens(malParser.COMMA) - else: - return self.getToken(malParser.COMMA, i) + return self.getToken(malParser.COMMA, i) def getRuleIndex(self): return malParser.RULE_precondition - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterPrecondition" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterPrecondition'): listener.enterPrecondition(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitPrecondition" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitPrecondition'): listener.exitPrecondition(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitPrecondition" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitPrecondition'): return visitor.visitPrecondition(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def precondition(self): - localctx = malParser.PreconditionContext(self, self._ctx, self.state) self.enterRule(localctx, 36, self.RULE_precondition) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 229 @@ -1752,7 +4497,7 @@ def precondition(self): self.state = 235 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==42: + while _la == 42: self.state = 231 self.match(malParser.COMMA) self.state = 232 @@ -1769,20 +4514,19 @@ def precondition(self): self.exitRule() return localctx - class ReachesContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def expr(self, i:int=None): + def expr(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.ExprContext) - else: - return self.getTypedRuleContext(malParser.ExprContext,i) - + return self.getTypedRuleContext(malParser.ExprContext, i) def INHERITS(self): return self.getToken(malParser.INHERITS, 0) @@ -1790,42 +4534,36 @@ def INHERITS(self): def LEADSTO(self): return self.getToken(malParser.LEADSTO, 0) - def COMMA(self, i:int=None): + def COMMA(self, i: int | None = None): if i is None: return self.getTokens(malParser.COMMA) - else: - return self.getToken(malParser.COMMA, i) + return self.getToken(malParser.COMMA, i) def getRuleIndex(self): return malParser.RULE_reaches - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterReaches" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterReaches'): listener.enterReaches(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitReaches" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitReaches'): listener.exitReaches(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitReaches" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitReaches'): return visitor.visitReaches(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def reaches(self): - localctx = malParser.ReachesContext(self, self._ctx, self.state) self.enterRule(localctx, 38, self.RULE_reaches) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 238 _la = self._input.LA(1) - if not(_la==40 or _la==41): + if _la not in {40, 41}: self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -1835,7 +4573,7 @@ def reaches(self): self.state = 244 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==42: + while _la == 42: self.state = 240 self.match(malParser.COMMA) self.state = 241 @@ -1852,11 +4590,12 @@ def reaches(self): self.exitRule() return localctx - class NumberContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -1869,33 +4608,28 @@ def FLOAT(self): def getRuleIndex(self): return malParser.RULE_number - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterNumber" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterNumber'): listener.enterNumber(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitNumber" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitNumber'): listener.exitNumber(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitNumber" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitNumber'): return visitor.visitNumber(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def number(self): - localctx = malParser.NumberContext(self, self._ctx, self.state) self.enterRule(localctx, 40, self.RULE_number) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 247 _la = self._input.LA(1) - if not(_la==10 or _la==11): + if _la not in {10, 11}: self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -1908,11 +4642,12 @@ def number(self): self.exitRule() return localctx - class VariableContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -1926,31 +4661,25 @@ def ASSIGN(self): return self.getToken(malParser.ASSIGN, 0) def expr(self): - return self.getTypedRuleContext(malParser.ExprContext,0) - + return self.getTypedRuleContext(malParser.ExprContext, 0) def getRuleIndex(self): return malParser.RULE_variable - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterVariable" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterVariable'): listener.enterVariable(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitVariable" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitVariable'): listener.exitVariable(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVariable" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitVariable'): return visitor.visitVariable(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def variable(self): - localctx = malParser.VariableContext(self, self._ctx, self.state) self.enterRule(localctx, 42, self.RULE_variable) try: @@ -1971,53 +4700,45 @@ def variable(self): self.exitRule() return localctx - class ExprContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def parts(self, i:int=None): + def parts(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.PartsContext) - else: - return self.getTypedRuleContext(malParser.PartsContext,i) - + return self.getTypedRuleContext(malParser.PartsContext, i) - def setop(self, i:int=None): + def setop(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.SetopContext) - else: - return self.getTypedRuleContext(malParser.SetopContext,i) - + return self.getTypedRuleContext(malParser.SetopContext, i) def getRuleIndex(self): return malParser.RULE_expr - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterExpr" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterExpr'): listener.enterExpr(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitExpr" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitExpr'): listener.exitExpr(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitExpr" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitExpr'): return visitor.visitExpr(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def expr(self): - localctx = malParser.ExprContext(self, self._ctx, self.state) self.enterRule(localctx, 44, self.RULE_expr) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 254 @@ -2025,7 +4746,7 @@ def expr(self): self.state = 260 self._errHandler.sync(self) _la = self._input.LA(1) - while (((_la) & ~0x3f) == 0 and ((1 << _la) & 7516192768) != 0): + while ((_la) & ~0x3F) == 0 and ((1 << _la) & 7516192768) != 0: self.state = 255 self.setop() self.state = 256 @@ -2042,52 +4763,45 @@ def expr(self): self.exitRule() return localctx - class PartsContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def part(self, i:int=None): + def part(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.PartContext) - else: - return self.getTypedRuleContext(malParser.PartContext,i) + return self.getTypedRuleContext(malParser.PartContext, i) - - def DOT(self, i:int=None): + def DOT(self, i: int | None = None): if i is None: return self.getTokens(malParser.DOT) - else: - return self.getToken(malParser.DOT, i) + return self.getToken(malParser.DOT, i) def getRuleIndex(self): return malParser.RULE_parts - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterParts" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterParts'): listener.enterParts(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitParts" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitParts'): listener.exitParts(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitParts" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitParts'): return visitor.visitParts(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def parts(self): - localctx = malParser.PartsContext(self, self._ctx, self.state) self.enterRule(localctx, 46, self.RULE_parts) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 263 @@ -2095,7 +4809,7 @@ def parts(self): self.state = 268 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==34: + while _la == 34: self.state = 264 self.match(malParser.DOT) self.state = 265 @@ -2112,11 +4826,12 @@ def parts(self): self.exitRule() return localctx - class PartContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2124,15 +4839,13 @@ def LPAREN(self): return self.getToken(malParser.LPAREN, 0) def expr(self): - return self.getTypedRuleContext(malParser.ExprContext,0) - + return self.getTypedRuleContext(malParser.ExprContext, 0) def RPAREN(self): return self.getToken(malParser.RPAREN, 0) def varsubst(self): - return self.getTypedRuleContext(malParser.VarsubstContext,0) - + return self.getTypedRuleContext(malParser.VarsubstContext, 0) def ID(self): return self.getToken(malParser.ID, 0) @@ -2140,43 +4853,36 @@ def ID(self): def STAR(self): return self.getToken(malParser.STAR, 0) - def type_(self, i:int=None): + def type_(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.TypeContext) - else: - return self.getTypedRuleContext(malParser.TypeContext,i) - + return self.getTypedRuleContext(malParser.TypeContext, i) def getRuleIndex(self): return malParser.RULE_part - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterPart" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterPart'): listener.enterPart(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitPart" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitPart'): listener.exitPart(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitPart" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitPart'): return visitor.visitPart(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def part(self): - localctx = malParser.PartContext(self, self._ctx, self.state) self.enterRule(localctx, 48, self.RULE_part) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 280 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + la_ = self._interp.adaptivePredict(self._input, 28, self._ctx) if la_ == 1: self.state = 271 self.match(malParser.LPAREN) @@ -2184,7 +4890,6 @@ def part(self): self.expr() self.state = 273 self.match(malParser.RPAREN) - pass elif la_ == 2: self.state = 275 @@ -2193,26 +4898,22 @@ def part(self): self.match(malParser.LPAREN) self.state = 277 self.match(malParser.RPAREN) - pass elif la_ == 3: self.state = 279 self.match(malParser.ID) - pass - self.state = 283 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==27: + if _la == 27: self.state = 282 self.match(malParser.STAR) - self.state = 288 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==25: + while _la == 25: self.state = 285 self.type_() self.state = 290 @@ -2227,11 +4928,12 @@ def part(self): self.exitRule() return localctx - class VarsubstContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2241,25 +4943,20 @@ def ID(self): def getRuleIndex(self): return malParser.RULE_varsubst - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterVarsubst" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterVarsubst'): listener.enterVarsubst(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitVarsubst" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitVarsubst'): listener.exitVarsubst(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitVarsubst" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitVarsubst'): return visitor.visitVarsubst(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def varsubst(self): - localctx = malParser.VarsubstContext(self, self._ctx, self.state) self.enterRule(localctx, 50, self.RULE_varsubst) try: @@ -2274,11 +4971,12 @@ def varsubst(self): self.exitRule() return localctx - class TypeContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2294,25 +4992,20 @@ def RSQUARE(self): def getRuleIndex(self): return malParser.RULE_type - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterType" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterType'): listener.enterType(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitType" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitType'): listener.exitType(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitType" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitType'): return visitor.visitType(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def type_(self): - localctx = malParser.TypeContext(self, self._ctx, self.state) self.enterRule(localctx, 52, self.RULE_type) try: @@ -2331,11 +5024,12 @@ def type_(self): self.exitRule() return localctx - class SetopContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2351,33 +5045,28 @@ def MINUS(self): def getRuleIndex(self): return malParser.RULE_setop - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterSetop" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterSetop'): listener.enterSetop(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitSetop" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitSetop'): listener.exitSetop(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitSetop" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitSetop'): return visitor.visitSetop(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def setop(self): - localctx = malParser.SetopContext(self, self._ctx, self.state) self.enterRule(localctx, 54, self.RULE_setop) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 297 _la = self._input.LA(1) - if not((((_la) & ~0x3f) == 0 and ((1 << _la) & 7516192768) != 0)): + if not (((_la) & ~0x3F) == 0 and ((1 << _la) & 7516192768) != 0): self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -2390,11 +5079,12 @@ def setop(self): self.exitRule() return localctx - class AssociationsContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2407,38 +5097,31 @@ def LCURLY(self): def RCURLY(self): return self.getToken(malParser.RCURLY, 0) - def association(self, i:int=None): + def association(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.AssociationContext) - else: - return self.getTypedRuleContext(malParser.AssociationContext,i) - + return self.getTypedRuleContext(malParser.AssociationContext, i) def getRuleIndex(self): return malParser.RULE_associations - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterAssociations" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterAssociations'): listener.enterAssociations(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitAssociations" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitAssociations'): listener.exitAssociations(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAssociations" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitAssociations'): return visitor.visitAssociations(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def associations(self): - localctx = malParser.AssociationsContext(self, self._ctx, self.state) self.enterRule(localctx, 56, self.RULE_associations) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 299 @@ -2448,7 +5131,7 @@ def associations(self): self.state = 304 self._errHandler.sync(self) _la = self._input.LA(1) - while _la==16: + while _la == 16: self.state = 301 self.association() self.state = 306 @@ -2465,73 +5148,61 @@ def associations(self): self.exitRule() return localctx - class AssociationContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def ID(self, i:int=None): + def ID(self, i: int | None = None): if i is None: return self.getTokens(malParser.ID) - else: - return self.getToken(malParser.ID, i) + return self.getToken(malParser.ID, i) - def field(self, i:int=None): + def field(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.FieldContext) - else: - return self.getTypedRuleContext(malParser.FieldContext,i) - + return self.getTypedRuleContext(malParser.FieldContext, i) - def mult(self, i:int=None): + def mult(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.MultContext) - else: - return self.getTypedRuleContext(malParser.MultContext,i) - + return self.getTypedRuleContext(malParser.MultContext, i) def LARROW(self): return self.getToken(malParser.LARROW, 0) def linkname(self): - return self.getTypedRuleContext(malParser.LinknameContext,0) - + return self.getTypedRuleContext(malParser.LinknameContext, 0) def RARROW(self): return self.getToken(malParser.RARROW, 0) - def meta(self, i:int=None): + def meta(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.MetaContext) - else: - return self.getTypedRuleContext(malParser.MetaContext,i) - + return self.getTypedRuleContext(malParser.MetaContext, i) def getRuleIndex(self): return malParser.RULE_association - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterAssociation" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterAssociation'): listener.enterAssociation(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitAssociation" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitAssociation'): listener.exitAssociation(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitAssociation" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitAssociation'): return visitor.visitAssociation(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def association(self): - localctx = malParser.AssociationContext(self, self._ctx, self.state) self.enterRule(localctx, 58, self.RULE_association) try: @@ -2556,14 +5227,14 @@ def association(self): self.match(malParser.ID) self.state = 321 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,32,self._ctx) - while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: - if _alt==1: + _alt = self._interp.adaptivePredict(self._input, 32, self._ctx) + while _alt not in {2, ATN.INVALID_ALT_NUMBER}: + if _alt == 1: self.state = 318 - self.meta() + self.meta() self.state = 323 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,32,self._ctx) + _alt = self._interp.adaptivePredict(self._input, 32, self._ctx) except RecognitionException as re: localctx.exception = re @@ -2573,11 +5244,12 @@ def association(self): self.exitRule() return localctx - class FieldContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2593,25 +5265,20 @@ def RSQUARE(self): def getRuleIndex(self): return malParser.RULE_field - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterField" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterField'): listener.enterField(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitField" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitField'): listener.exitField(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitField" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitField'): return visitor.visitField(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def field(self): - localctx = malParser.FieldContext(self, self._ctx, self.state) self.enterRule(localctx, 60, self.RULE_field) try: @@ -2630,20 +5297,19 @@ def field(self): self.exitRule() return localctx - class MultContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser - def multatom(self, i:int=None): + def multatom(self, i: int | None = None): if i is None: return self.getTypedRuleContexts(malParser.MultatomContext) - else: - return self.getTypedRuleContext(malParser.MultatomContext,i) - + return self.getTypedRuleContext(malParser.MultatomContext, i) def RANGE(self): return self.getToken(malParser.RANGE, 0) @@ -2651,28 +5317,23 @@ def RANGE(self): def getRuleIndex(self): return malParser.RULE_mult - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterMult" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterMult'): listener.enterMult(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitMult" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitMult'): listener.exitMult(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMult" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitMult'): return visitor.visitMult(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def mult(self): - localctx = malParser.MultContext(self, self._ctx, self.state) self.enterRule(localctx, 62, self.RULE_mult) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 328 @@ -2680,13 +5341,12 @@ def mult(self): self.state = 331 self._errHandler.sync(self) _la = self._input.LA(1) - if _la==33: + if _la == 33: self.state = 329 self.match(malParser.RANGE) self.state = 330 self.multatom() - except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -2695,11 +5355,12 @@ def mult(self): self.exitRule() return localctx - class MultatomContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2712,33 +5373,28 @@ def STAR(self): def getRuleIndex(self): return malParser.RULE_multatom - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterMultatom" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterMultatom'): listener.enterMultatom(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitMultatom" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitMultatom'): listener.exitMultatom(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitMultatom" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitMultatom'): return visitor.visitMultatom(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def multatom(self): - localctx = malParser.MultatomContext(self, self._ctx, self.state) self.enterRule(localctx, 64, self.RULE_multatom) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) self.state = 333 _la = self._input.LA(1) - if not(_la==10 or _la==27): + if _la not in {10, 27}: self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) @@ -2751,11 +5407,12 @@ def multatom(self): self.exitRule() return localctx - class LinknameContext(ParserRuleContext): __slots__ = 'parser' - def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + def __init__( + self, parser, parent: ParserRuleContext = None, invokingState: int = -1 + ) -> None: super().__init__(parent, invokingState) self.parser = parser @@ -2765,25 +5422,20 @@ def ID(self): def getRuleIndex(self): return malParser.RULE_linkname - def enterRule(self, listener:ParseTreeListener): - if hasattr( listener, "enterLinkname" ): + def enterRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'enterLinkname'): listener.enterLinkname(self) - def exitRule(self, listener:ParseTreeListener): - if hasattr( listener, "exitLinkname" ): + def exitRule(self, listener: ParseTreeListener) -> None: + if hasattr(listener, 'exitLinkname'): listener.exitLinkname(self) - def accept(self, visitor:ParseTreeVisitor): - if hasattr( visitor, "visitLinkname" ): + def accept(self, visitor: ParseTreeVisitor): + if hasattr(visitor, 'visitLinkname'): return visitor.visitLinkname(self) - else: - return visitor.visitChildren(self) - - - + return visitor.visitChildren(self) def linkname(self): - localctx = malParser.LinknameContext(self, self._ctx, self.state) self.enterRule(localctx, 66, self.RULE_linkname) try: @@ -2797,8 +5449,3 @@ def linkname(self): finally: self.exitRule() return localctx - - - - - diff --git a/maltoolbox/language/compiler/mal_visitor.py b/maltoolbox/language/compiler/mal_visitor.py index d94d9f1a..0cbcf7f5 100644 --- a/maltoolbox/language/compiler/mal_visitor.py +++ b/maltoolbox/language/compiler/mal_visitor.py @@ -2,14 +2,16 @@ from collections.abc import MutableMapping, MutableSequence from antlr4 import ParseTreeVisitor + from .mal_parser import malParser # In a rule like `rule: one? two* three`: # - ctx.one() would be None if the token was not found on a matching line # - ctx.two() would be [] + class malVisitor(ParseTreeVisitor): - def __init__(self, compiler, *args, **kwargs): + def __init__(self, compiler, *args, **kwargs) -> None: self.compiler = compiler self.current_file = compiler.current_file # for debug purposes @@ -17,11 +19,11 @@ def __init__(self, compiler, *args, **kwargs): def visitMal(self, ctx): langspec = { - "formatVersion": "1.0.0", - "defines": {}, - "categories": [], - "assets": [], - "associations": [], + 'formatVersion': '1.0.0', + 'defines': {}, + 'categories': [], + 'assets': [], + 'associations': [], } # no visitDeclaration method needed, `declaration` is a thin rule @@ -29,27 +31,27 @@ def visitMal(self, ctx): if result := self.visit(declaration) or True: key, value = result - if key == "categories": + if key == 'categories': category, assets = value - langspec["categories"].extend(category) - langspec["assets"].extend(assets) + langspec['categories'].extend(category) + langspec['assets'].extend(assets) continue - if key == "defines": + if key == 'defines': langspec[key].update(value) - if key == "associations": + if key == 'associations': langspec[key].extend(value) - if key == "include": + if key == 'include': included_file = self.compiler.compile(value) for k, v in langspec.items(): if isinstance(v, MutableMapping): - langspec[k].update(included_file.get(k, {})) + v.update(included_file.get(k, {})) if isinstance(v, MutableSequence) and k in included_file: - langspec[k].extend(included_file[k]) + v.extend(included_file[k]) - for key in ("categories", "assets", "associations"): + for key in ('categories', 'assets', 'associations'): unique = [] for item in langspec[key]: if item not in unique: @@ -59,65 +61,65 @@ def visitMal(self, ctx): return langspec def visitInclude(self, ctx): - return ("include", ctx.STRING().getText().strip('"')) + return ('include', ctx.STRING().getText().strip('"')) def visitDefine(self, ctx): - return ("defines", {ctx.ID().getText(): ctx.STRING().getText().strip('"')}) + return ('defines', {ctx.ID().getText(): ctx.STRING().getText().strip('"')}) def visitCategory(self, ctx): category = {} - category["name"] = ctx.ID().getText() - category["meta"] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} + category['name'] = ctx.ID().getText() + category['meta'] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} assets = [self.visit(asset) for asset in ctx.asset()] - return ("categories", ([category], assets)) + return ('categories', ([category], assets)) def visitMeta(self, ctx): return ((ctx.ID().getText(), ctx.STRING().getText().strip('"')),) def visitAsset(self, ctx): asset = {} - asset["name"] = ctx.ID()[0].getText() - asset["meta"] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} - asset["category"] = ctx.parentCtx.ID().getText() - asset["isAbstract"] = ctx.ABSTRACT() is not None + asset['name'] = ctx.ID()[0].getText() + asset['meta'] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} + asset['category'] = ctx.parentCtx.ID().getText() + asset['isAbstract'] = ctx.ABSTRACT() is not None - asset["superAsset"] = None + asset['superAsset'] = None if len(ctx.ID()) > 1 and ctx.ID()[1]: - asset["superAsset"] = ctx.ID()[1].getText() + asset['superAsset'] = ctx.ID()[1].getText() - asset["variables"] = [self.visit(variable) for variable in ctx.variable()] - asset["attackSteps"] = [self.visit(step) for step in ctx.step()] + asset['variables'] = [self.visit(variable) for variable in ctx.variable()] + asset['attackSteps'] = [self.visit(step) for step in ctx.step()] return asset def visitStep(self, ctx): step = {} - step["name"] = ctx.ID().getText() - step["meta"] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} - step["type"] = self.visit(ctx.steptype()) - step["tags"] = [self.visit(tag) for tag in ctx.tag()] - step["risk"] = self.visit(ctx.cias()) if ctx.cias() else None - step["ttc"] = self.visit(ctx.ttc()) if ctx.ttc() else None - step["requires"] = ( + step['name'] = ctx.ID().getText() + step['meta'] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} + step['type'] = self.visit(ctx.steptype()) + step['tags'] = [self.visit(tag) for tag in ctx.tag()] + step['risk'] = self.visit(ctx.cias()) if ctx.cias() else None + step['ttc'] = self.visit(ctx.ttc()) if ctx.ttc() else None + step['requires'] = ( self.visit(ctx.precondition()) if ctx.precondition() else None ) - step["reaches"] = self.visit(ctx.reaches()) if ctx.reaches() else None + step['reaches'] = self.visit(ctx.reaches()) if ctx.reaches() else None return step - def visitSteptype(self, ctx): + def visitSteptype(self, ctx) -> str | None: return ( - "or" + 'or' if ctx.OR() - else "and" + else 'and' if ctx.AND() - else "defense" + else 'defense' if ctx.HASH() - else "exist" + else 'exist' if ctx.EXISTS() - else "notExist" + else 'notExist' if ctx.NOTEXISTS() else None # should never happen, the grammar limits it ) @@ -127,9 +129,9 @@ def visitTag(self, ctx): def visitCias(self, ctx): risk = { - "isConfidentiality": False, - "isIntegrity": False, - "isAvailability": False, + 'isConfidentiality': False, + 'isIntegrity': False, + 'isAvailability': False, } for cia in ctx.cia(): @@ -139,11 +141,11 @@ def visitCias(self, ctx): def visitCia(self, ctx): key = ( - "isConfidentiality" + 'isConfidentiality' if ctx.C() - else "isIntegrity" + else 'isIntegrity' if ctx.I() - else "isAvailability" + else 'isAvailability' if ctx.A() else None ) @@ -151,9 +153,7 @@ def visitCia(self, ctx): return {key: True} def visitTtc(self, ctx): - ret = self.visit(ctx.ttcexpr()) - - return ret + return self.visit(ctx.ttcexpr()) def visitTtcexpr(self, ctx): if len(terms := ctx.ttcterm()) == 1: @@ -163,13 +163,13 @@ def visitTtcexpr(self, ctx): lhs = self.visit(terms[0]) for i in range(1, len(terms)): - ret["type"] = ( - "addition" - if ctx.children[2 * i - 1].getText() == "+" - else "subtraction" + ret['type'] = ( + 'addition' + if ctx.children[2 * i - 1].getText() == '+' + else 'subtraction' ) - ret["lhs"] = lhs - ret["rhs"] = self.visit(terms[i]) + ret['lhs'] = lhs + ret['rhs'] = self.visit(terms[i]) lhs = ret.copy() @@ -180,9 +180,9 @@ def visitTtcterm(self, ctx): ret = self.visit(factors[0]) else: ret = {} - ret["type"] = "multiplication" if ctx.STAR() else "division" - ret["lhs"] = self.visit(factors[0]) - ret["rhs"] = self.visit(factors[1]) + ret['type'] = 'multiplication' if ctx.STAR() else 'division' + ret['lhs'] = self.visit(factors[0]) + ret['rhs'] = self.visit(factors[1]) return ret @@ -191,9 +191,9 @@ def visitTtcfact(self, ctx): ret = self.visit(atoms[0]) else: ret = {} - ret["type"] = "exponentiation" - ret["lhs"] = self.visit(atoms[0]) - ret["rhs"] = self.visit(atoms[1]) + ret['type'] = 'exponentiation' + ret['lhs'] = self.visit(atoms[0]) + ret['rhs'] = self.visit(atoms[1]) return ret @@ -208,38 +208,38 @@ def visitTtcatom(self, ctx): return ret def visitTtcdist(self, ctx): - ret = {"type": "function"} - ret["name"] = ctx.ID().getText() - ret["arguments"] = [] + ret = {'type': 'function'} + ret['name'] = ctx.ID().getText() + ret['arguments'] = [] if ctx.LPAREN(): - ret["arguments"] = [self.visit(number)["value"] for number in ctx.number()] + ret['arguments'] = [self.visit(number)['value'] for number in ctx.number()] return ret def visitPrecondition(self, ctx): ret = {} - ret["overrides"] = True - ret["stepExpressions"] = [self.visit(expr) for expr in ctx.expr()] + ret['overrides'] = True + ret['stepExpressions'] = [self.visit(expr) for expr in ctx.expr()] return ret def visitReaches(self, ctx): ret = {} - ret["overrides"] = ctx.INHERITS() is None - ret["stepExpressions"] = [self.visit(expr) for expr in ctx.expr()] + ret['overrides'] = ctx.INHERITS() is None + ret['stepExpressions'] = [self.visit(expr) for expr in ctx.expr()] return ret def visitNumber(self, ctx): - ret = {"type": "number"} - ret["value"] = float(ctx.getText()) + ret = {'type': 'number'} + ret['value'] = float(ctx.getText()) return ret def visitVariable(self, ctx): ret = {} - ret["name"] = ctx.ID().getText() - ret["stepExpression"] = self.visit(ctx.expr()) + ret['name'] = ctx.ID().getText() + ret['stepExpression'] = self.visit(ctx.expr()) return ret @@ -250,9 +250,9 @@ def visitExpr(self, ctx): ret = {} lhs = self.visit(ctx.parts()[0]) for i in range(1, len(ctx.parts())): - ret["type"] = self.visit(ctx.children[2 * i - 1]) - ret["lhs"] = lhs - ret["rhs"] = self.visit(ctx.parts()[i]) + ret['type'] = self.visit(ctx.children[2 * i - 1]) + ret['lhs'] = lhs + ret['rhs'] = self.visit(ctx.parts()[i]) lhs = ret.copy() return ret @@ -266,9 +266,9 @@ def visitParts(self, ctx): lhs = self.visit(ctx.part()[0]) for i in range(1, len(ctx.part())): - ret["type"] = "collect" - ret["lhs"] = lhs - ret["rhs"] = self.visit(ctx.part()[i]) + ret['type'] = 'collect' + ret['lhs'] = lhs + ret['rhs'] = self.visit(ctx.part()[i]) lhs = ret.copy() @@ -277,36 +277,36 @@ def visitParts(self, ctx): def visitPart(self, ctx): ret = {} if ctx.varsubst(): - ret["type"] = "variable" - ret["name"] = self.visit(ctx.varsubst()) + ret['type'] = 'variable' + ret['name'] = self.visit(ctx.varsubst()) elif ctx.LPAREN(): ret = self.visit(ctx.expr()) else: # ctx.ID() # Resolve type: field or attackStep? - ret["type"] = self._resolve_part_ID_type(ctx) + ret['type'] = self._resolve_part_ID_type(ctx) - ret["name"] = ctx.ID().getText() + ret['name'] = ctx.ID().getText() if ctx.STAR(): - ret = {"type": "transitive", "stepExpression": ret} + ret = {'type': 'transitive', 'stepExpression': ret} for type_ in ctx.type_(): # mind the trailing underscore ret = { - "type": "subType", - "subType": self.visit(type_), - "stepExpression": ret, + 'type': 'subType', + 'subType': self.visit(type_), + 'stepExpression': ret, } return ret - def _resolve_part_ID_type(self, ctx): + def _resolve_part_ID_type(self, ctx) -> str: pctx = ctx.parentCtx # Traverse up the tree until we find the parent of the topmost expr # (saying "topmost" as expr can be nested) or the root of the tree. while pctx and not isinstance( pctx, - malParser.ReachesContext + malParser.ReachesContext, # Expressions are also valid in `let` variable assignments, but # there every lexical component of expr is considered a "field", # no need to resolve the type in that case. Similarly, preconditions @@ -316,21 +316,21 @@ def _resolve_part_ID_type(self, ctx): if pctx is None: # ctx (the `part`) belongs to a "let" assignment or a precondition. - return "field" + return 'field' # scan for a dot to the right of `ctx` file_tokens = ctx.parser.getTokenStream().tokens for i in range(ctx.start.tokenIndex, pctx.stop.tokenIndex + 1): if file_tokens[i].type == malParser.DOT: - return "field" + return 'field' # We are looping until the end of pctx (which is a `reaches` or # `precondition` context). This could include multiple comma # separated `expr`s, we only care for the current one. if file_tokens[i].type == malParser.COMMA: # end of current `expr` - return "attackStep" + return 'attackStep' - return "attackStep" + return 'attackStep' def visitVarsubst(self, ctx): return ctx.ID().getText() @@ -338,67 +338,65 @@ def visitVarsubst(self, ctx): def visitType(self, ctx): return ctx.ID().getText() - def visitSetop(self, ctx): + def visitSetop(self, ctx) -> str | None: return ( - "union" + 'union' if ctx.UNION() - else "intersection" + else 'intersection' if ctx.INTERSECT() - else "difference" + else 'difference' if ctx.INTERSECT else None ) def visitAssociations(self, ctx): - associations = [] - for assoc in ctx.association(): - associations.append(self.visit(assoc)) + associations = [self.visit(assoc) for assoc in ctx.association()] - return ("associations", associations) + return ('associations', associations) def visitAssociation(self, ctx): association = {} - association["name"] = self.visit(ctx.linkname()) - association["meta"] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} - association["leftAsset"] = ctx.ID()[0].getText() - association["leftField"] = self.visit(ctx.field()[0]) + association['name'] = self.visit(ctx.linkname()) + association['meta'] = {k: v for meta in ctx.meta() for k, v in self.visit(meta)} + association['leftAsset'] = ctx.ID()[0].getText() + association['leftField'] = self.visit(ctx.field()[0]) # no self.visitMult or self.visitMultatom methods, reading them here # directly - association["leftMultiplicity"] = { - "min": (multatoms := ctx.mult()[0].multatom()).pop(0).getText(), - "max": multatoms.pop().getText() if multatoms else None, + association['leftMultiplicity'] = { + 'min': (multatoms := ctx.mult()[0].multatom()).pop(0).getText(), + 'max': multatoms.pop().getText() if multatoms else None, } - association["rightAsset"] = ctx.ID()[1].getText() - association["rightField"] = self.visit(ctx.field()[1]) - association["rightMultiplicity"] = { - "min": (multatoms := ctx.mult()[1].multatom()).pop(0).getText(), - "max": multatoms.pop().getText() if multatoms else None, + association['rightAsset'] = ctx.ID()[1].getText() + association['rightField'] = self.visit(ctx.field()[1]) + association['rightMultiplicity'] = { + 'min': (multatoms := ctx.mult()[1].multatom()).pop(0).getText(), + 'max': multatoms.pop().getText() if multatoms else None, } self._post_process_multitudes(association) return association - def _post_process_multitudes(self, association): + def _post_process_multitudes(self, association) -> None: mult_keys = [ # start the multatoms from right to left to make sure the rules # below get applied cleanly - "rightMultiplicity.max", - "rightMultiplicity.min", - "leftMultiplicity.max", - "leftMultiplicity.min", + 'rightMultiplicity.max', + 'rightMultiplicity.min', + 'leftMultiplicity.max', + 'leftMultiplicity.min', ] for mult_key in mult_keys: - key, subkey = mult_key.split(".") + key, subkey = mult_key.split('.') # upper limit equals lower limit if not given - if subkey == "max" and association[key][subkey] is None: - association[key][subkey] = association[key]["min"] + if subkey == 'max' and association[key][subkey] is None: + association[key][subkey] = association[key]['min'] - if association[key][subkey] == "*": + if association[key][subkey] == '*': # 'any' as lower limit means start from 0 - if subkey == "min": + if subkey == 'min': association[key][subkey] = 0 # 'any' as upper limit means not limit diff --git a/maltoolbox/language/languagegraph.py b/maltoolbox/language/languagegraph.py index c99fc8d5..f69d7ea9 100644 --- a/maltoolbox/language/languagegraph.py +++ b/maltoolbox/language/languagegraph.py @@ -1,90 +1,81 @@ -""" -MAL-Toolbox Language Graph Module -""" +"""MAL-Toolbox Language Graph Module.""" from __future__ import annotations -import logging import json +import logging import zipfile - from dataclasses import dataclass, field from functools import cached_property -from typing import Any, Optional +from typing import Any -from maltoolbox.file_utils import ( - load_dict_from_yaml_file, load_dict_from_json_file, - save_dict_to_file -) -from .compiler import MalCompiler -from ..exceptions import ( +from maltoolbox.exceptions import ( LanguageGraphAssociationError, - LanguageGraphStepExpressionError, LanguageGraphException, - LanguageGraphSuperAssetNotFoundError + LanguageGraphStepExpressionError, + LanguageGraphSuperAssetNotFoundError, ) +from maltoolbox.file_utils import ( + load_dict_from_json_file, + load_dict_from_yaml_file, + save_dict_to_file, +) + +from .compiler import MalCompiler logger = logging.getLogger(__name__) -def disaggregate_attack_step_full_name( - attack_step_full_name: str) -> list[str]: +def disaggregate_attack_step_full_name(attack_step_full_name: str) -> list[str]: return attack_step_full_name.split(':') @dataclass class LanguageGraphAsset: name: str - own_associations: dict[str, LanguageGraphAssociation] = \ - field(default_factory = dict) - attack_steps: dict[str, LanguageGraphAttackStep] = \ - field(default_factory = dict) - info: dict = field(default_factory = dict) - own_super_asset: Optional[LanguageGraphAsset] = None - own_sub_assets: set[LanguageGraphAsset] = field(default_factory = set) - own_variables: dict = field(default_factory = dict) - is_abstract: Optional[bool] = None - + own_associations: dict[str, LanguageGraphAssociation] = field(default_factory=dict) + attack_steps: dict[str, LanguageGraphAttackStep] = field(default_factory=dict) + info: dict = field(default_factory=dict) + own_super_asset: LanguageGraphAsset | None = None + own_sub_assets: set[LanguageGraphAsset] = field(default_factory=set) + own_variables: dict = field(default_factory=dict) + is_abstract: bool | None = None def to_dict(self) -> dict: - """Convert LanguageGraphAsset to dictionary""" + """Convert LanguageGraphAsset to dictionary.""" node_dict: dict[str, Any] = { 'name': self.name, 'associations': {}, 'attack_steps': {}, 'info': self.info, - 'super_asset': self.own_super_asset.name \ - if self.own_super_asset else "", + 'super_asset': self.own_super_asset.name if self.own_super_asset else '', 'sub_assets': [asset.name for asset in self.own_sub_assets], 'variables': {}, - 'is_abstract': self.is_abstract + 'is_abstract': self.is_abstract, } for fieldname, assoc in self.own_associations.items(): node_dict['associations'][fieldname] = assoc.to_dict() for attack_step in self.attack_steps.values(): - node_dict['attack_steps'][attack_step.name] = \ - attack_step.to_dict() - for variable_name, (var_target_asset, var_expr_chain) in \ - self.own_variables.items(): + node_dict['attack_steps'][attack_step.name] = attack_step.to_dict() + for variable_name, ( + var_target_asset, + var_expr_chain, + ) in self.own_variables.items(): node_dict['variables'][variable_name] = ( var_target_asset.name, - var_expr_chain.to_dict() + var_expr_chain.to_dict(), ) return node_dict - def __repr__(self) -> str: return str(self.to_dict()) - def __hash__(self): return hash(self.name) - def is_subasset_of(self, target_asset: LanguageGraphAsset) -> bool: - """ - Check if an asset extends the target asset through inheritance. + """Check if an asset extends the target asset through inheritance. Arguments: target_asset - the target asset we wish to evaluate if this asset @@ -93,9 +84,10 @@ def is_subasset_of(self, target_asset: LanguageGraphAsset) -> bool: Return: True if this asset extends the target_asset via inheritance. False otherwise. + """ - current_asset: Optional[LanguageGraphAsset] = self - while (current_asset): + current_asset: LanguageGraphAsset | None = self + while current_asset: if current_asset == target_asset: return True current_asset = current_asset.own_super_asset @@ -103,12 +95,12 @@ def is_subasset_of(self, target_asset: LanguageGraphAsset) -> bool: @cached_property def sub_assets(self) -> set[LanguageGraphAsset]: - """ - Return a list of all of the assets that directly or indirectly extend + """Return a list of all of the assets that directly or indirectly extend this asset. Return: A list of all of the assets that extend this asset plus itself. + """ subassets: list[LanguageGraphAsset] = [] for subasset in self.own_sub_assets: @@ -119,90 +111,78 @@ def sub_assets(self) -> set[LanguageGraphAsset]: return set(subassets) - @cached_property def super_assets(self) -> list[LanguageGraphAsset]: - """ - Return a list of all of the assets that this asset directly or + """Return a list of all of the assets that this asset directly or indirectly extends. Return: A list of all of the assets that this asset extends plus itself. + """ - current_asset: Optional[LanguageGraphAsset] = self + current_asset: LanguageGraphAsset | None = self superassets = [] - while (current_asset): + while current_asset: superassets.append(current_asset) current_asset = current_asset.own_super_asset return superassets - @cached_property def associations(self) -> dict[str, LanguageGraphAssociation]: - """ - Return a list of all of the associations that belong to this asset + """Return a list of all of the associations that belong to this asset directly or indirectly via inheritance. Return: A list of all of the associations that apply to this asset, either directly or via inheritance. - """ + """ associations = dict(self.own_associations) if self.own_super_asset: associations |= self.own_super_asset.associations return associations - @property def variables(self) -> dict[str, ExpressionsChain]: - """ - Return a list of all of the variables that belong to this asset + """Return a list of all of the variables that belong to this asset directly or indirectly via inheritance. Return: A list of all of the variables that apply to this asset, either directly or via inheritance. - """ + """ all_vars = dict(self.own_variables) if self.own_super_asset: all_vars |= self.own_super_asset.variables return all_vars - def get_variable( self, var_name: str, - ) -> Optional[tuple]: - """ - Return a variable matching the given name if the asset or any of its + ) -> tuple | None: + """Return a variable matching the given name if the asset or any of its super assets has its definition. Return: A tuple containing the target asset and expressions chain to it if the variable was defined. None otherwise. + """ - current_asset: Optional[LanguageGraphAsset] = self - while (current_asset): + current_asset: LanguageGraphAsset | None = self + while current_asset: if var_name in current_asset.own_variables: return current_asset.own_variables[var_name] current_asset = current_asset.own_super_asset return None - - def get_all_common_superassets( - self, other: LanguageGraphAsset - ) -> set[str]: + def get_all_common_superassets(self, other: LanguageGraphAsset) -> set[str]: """Return a set of all common ancestors between this asset - and the other asset given as parameter""" - self_superassets = set( - asset.name for asset in self.super_assets - ) - other_superassets = set( - asset.name for asset in other.super_assets - ) + and the other asset given as parameter. + """ + self_superassets = {asset.name for asset in self.super_assets} + other_superassets = {asset.name for asset in other.super_assets} return self_superassets.intersection(other_superassets) @@ -219,68 +199,53 @@ class LanguageGraphAssociation: name: str left_field: LanguageGraphAssociationField right_field: LanguageGraphAssociationField - info: dict = field(default_factory = dict) + info: dict = field(default_factory=dict) def to_dict(self) -> dict: - """Convert LanguageGraphAssociation to dictionary""" - assoc_dict = { + """Convert LanguageGraphAssociation to dictionary.""" + return { 'name': self.name, 'info': self.info, 'left': { 'asset': self.left_field.asset.name, 'fieldname': self.left_field.fieldname, 'min': self.left_field.minimum, - 'max': self.left_field.maximum + 'max': self.left_field.maximum, }, 'right': { 'asset': self.right_field.asset.name, 'fieldname': self.right_field.fieldname, 'min': self.right_field.minimum, - 'max': self.right_field.maximum - } + 'max': self.right_field.maximum, + }, } - return assoc_dict - - def __repr__(self) -> str: return str(self.to_dict()) - @property def full_name(self) -> str: - """ - Return the full name of the association. This is a combination of the + """Return the full name of the association. This is a combination of the association name, left field name, left asset type, right field name, and right asset type. """ - full_name = '%s_%s_%s' % ( - self.name,\ - self.left_field.fieldname,\ - self.right_field.fieldname - ) - return full_name - + return f'{self.name}_{self.left_field.fieldname}_{self.right_field.fieldname}' def contains_fieldname(self, fieldname: str) -> bool: - """ - Check if the association contains the field name given as a parameter. + """Check if the association contains the field name given as a parameter. Arguments: fieldname - the field name to look for Return True if either of the two field names matches. False, otherwise. + """ if self.left_field.fieldname == fieldname: return True - if self.right_field.fieldname == fieldname: - return True - return False - + return self.right_field.fieldname == fieldname def contains_asset(self, asset: Any) -> bool: - """ - Check if the association matches the asset given as a parameter. A + """Check if the association matches the asset given as a parameter. A match can either be an explicit one or if the asset given subassets either of the two assets that are part of the association. @@ -288,31 +253,28 @@ def contains_asset(self, asset: Any) -> bool: asset - the asset to look for Return True if either of the two asset matches. False, otherwise. + """ if asset.is_subasset_of(self.left_field.asset): return True - if asset.is_subasset_of(self.right_field.asset): - return True - return False - + return bool(asset.is_subasset_of(self.right_field.asset)) def get_opposite_fieldname(self, fieldname: str) -> str: - """ - Return the opposite field name if the association contains the field + """Return the opposite field name if the association contains the field name given as a parameter. Arguments: fieldname - the field name to look for Return the other field name if the parameter matched either of the two. None, otherwise. + """ if self.left_field.fieldname == fieldname: return self.right_field.fieldname if self.right_field.fieldname == fieldname: return self.left_field.fieldname - msg = ('Requested fieldname "%s" from association ' - '%s which did not contain it!') + msg = 'Requested fieldname "%s" from association %s which did not contain it!' logger.error(msg, fieldname, self.name) raise LanguageGraphAssociationError(msg % (fieldname, self.name)) @@ -322,29 +284,25 @@ class LanguageGraphAttackStep: name: str type: str asset: LanguageGraphAsset - ttc: dict = field(default_factory = dict) + ttc: dict = field(default_factory=dict) overrides: bool = False - children: dict = field(default_factory = dict) - parents: dict = field(default_factory = dict) - info: dict = field(default_factory = dict) - inherits: Optional[LanguageGraphAttackStep] = None - tags: set = field(default_factory = set) - _attributes: Optional[dict] = None - + children: dict = field(default_factory=dict) + parents: dict = field(default_factory=dict) + info: dict = field(default_factory=dict) + inherits: LanguageGraphAttackStep | None = None + tags: set = field(default_factory=set) + _attributes: dict | None = None def __hash__(self): return hash(self.full_name) - @property def full_name(self) -> str: - """ - Return the full name of the attack step. This is a combination of the + """Return the full name of the attack step. This is a combination of the asset type name to which the attack step belongs and attack step name itself. """ - full_name = self.asset.name + ':' + self.name - return full_name + return self.asset.name + ':' + self.name def to_dict(self) -> dict: node_dict: dict[Any, Any] = { @@ -357,24 +315,22 @@ def to_dict(self) -> dict: 'info': self.info, 'overrides': self.overrides, 'inherits': self.inherits.full_name if self.inherits else None, - 'tags': list(self.tags) + 'tags': list(self.tags), } for child in self.children: node_dict['children'][child] = [] - for (_, expr_chain) in self.children[child]: + for _, expr_chain in self.children[child]: if expr_chain: - node_dict['children'][child].append( - expr_chain.to_dict()) + node_dict['children'][child].append(expr_chain.to_dict()) else: node_dict['children'][child].append(None) for parent in self.parents: node_dict['parents'][parent] = [] - for (_, expr_chain) in self.parents[parent]: + for _, expr_chain in self.parents[parent]: if expr_chain: - node_dict['parents'][parent].append( - expr_chain.to_dict()) + node_dict['parents'][parent].append(expr_chain.to_dict()) else: node_dict['parents'][parent].append(None) @@ -385,107 +341,86 @@ def to_dict(self) -> dict: return node_dict - @cached_property def requires(self): - if not hasattr(self, 'own_requires'): - requirements = [] - else: - requirements = self.own_requires + requirements = [] if not hasattr(self, 'own_requires') else self.own_requires if self.inherits: requirements.extend(self.inherits.requires) return requirements - def __repr__(self) -> str: return str(self.to_dict()) class ExpressionsChain: - def __init__(self, - type: str, - left_link: Optional[ExpressionsChain] = None, - right_link: Optional[ExpressionsChain] = None, - sub_link: Optional[ExpressionsChain] = None, - fieldname: Optional[str] = None, - association = None, - subtype = None - ): + def __init__( + self, + type: str, + left_link: ExpressionsChain | None = None, + right_link: ExpressionsChain | None = None, + sub_link: ExpressionsChain | None = None, + fieldname: str | None = None, + association=None, + subtype=None, + ) -> None: self.type = type - self.left_link: Optional[ExpressionsChain] = left_link - self.right_link: Optional[ExpressionsChain] = right_link - self.sub_link: Optional[ExpressionsChain] = sub_link - self.fieldname: Optional[str] = fieldname - self.association: Optional[LanguageGraphAssociation] = association - self.subtype: Optional[Any] = subtype - + self.left_link: ExpressionsChain | None = left_link + self.right_link: ExpressionsChain | None = right_link + self.sub_link: ExpressionsChain | None = sub_link + self.fieldname: str | None = fieldname + self.association: LanguageGraphAssociation | None = association + self.subtype: Any | None = subtype def to_dict(self) -> dict: - """Convert ExpressionsChain to dictionary""" - match (self.type): + """Convert ExpressionsChain to dictionary.""" + match self.type: case 'union' | 'intersection' | 'difference' | 'collect': return { self.type: { - 'left': self.left_link.to_dict() - if self.left_link else {}, - 'right': self.right_link.to_dict() - if self.right_link else {} + 'left': self.left_link.to_dict() if self.left_link else {}, + 'right': self.right_link.to_dict() if self.right_link else {}, }, - 'type': self.type + 'type': self.type, } case 'field': if not self.association: - raise LanguageGraphAssociationError( - "Missing association for expressions chain" - ) + msg = 'Missing association for expressions chain' + raise LanguageGraphAssociationError(msg) if self.fieldname == self.association.left_field.fieldname: asset_type = self.association.left_field.asset.name elif self.fieldname == self.association.right_field.fieldname: asset_type = self.association.right_field.asset.name else: - raise LanguageGraphException( - 'Failed to find fieldname "%s" in association:\n%s' % - ( - self.fieldname, - json.dumps(self.association.to_dict(), - indent = 2) - ) - ) + msg = f'Failed to find fieldname "{self.fieldname}" in association:\n{json.dumps(self.association.to_dict(), indent=2)}' + raise LanguageGraphException(msg) return { - self.association.name: - { + self.association.name: { 'fieldname': self.fieldname, - 'asset type': asset_type + 'asset type': asset_type, }, - 'type': self.type + 'type': self.type, } case 'transitive': if not self.sub_link: - raise LanguageGraphException( - "No sub link for transitive expressions chain" - ) - return { - 'transitive': self.sub_link.to_dict(), - 'type': self.type - } + msg = 'No sub link for transitive expressions chain' + raise LanguageGraphException(msg) + return {'transitive': self.sub_link.to_dict(), 'type': self.type} case 'subType': if not self.subtype: - raise LanguageGraphException( - "No subtype for expressions chain" - ) + msg = 'No subtype for expressions chain' + raise LanguageGraphException(msg) if not self.sub_link: - raise LanguageGraphException( - "No sub link for subtype expressions chain" - ) + msg = 'No sub link for subtype expressions chain' + raise LanguageGraphException(msg) return { 'subType': self.subtype.name, 'expression': self.sub_link.to_dict(), - 'type': self.type + 'type': self.type, } case _: @@ -494,88 +429,72 @@ def to_dict(self) -> dict: raise LanguageGraphAssociationError(msg % self.type) @classmethod - def _from_dict(cls, - serialized_expr_chain: dict, - lang_graph: LanguageGraph, - ) -> Optional[ExpressionsChain]: + def _from_dict( + cls, + serialized_expr_chain: dict, + lang_graph: LanguageGraph, + ) -> ExpressionsChain | None: """Create ExpressionsChain from dict Args: serialized_expr_chain - expressions chain in dict format lang_graph - the LanguageGraph that contains the assets, associations, and attack steps relevant for - the expressions chain + the expressions chain. """ - if serialized_expr_chain is None or not serialized_expr_chain: return None if 'type' not in serialized_expr_chain: - logger.debug(json.dumps(serialized_expr_chain, indent = 2)) + logger.debug(json.dumps(serialized_expr_chain, indent=2)) msg = 'Missing expressions chain type!' logger.error(msg) raise LanguageGraphAssociationError(msg) return None expr_chain_type = serialized_expr_chain['type'] - match (expr_chain_type): + match expr_chain_type: case 'union' | 'intersection' | 'difference' | 'collect': left_link = cls._from_dict( - serialized_expr_chain[expr_chain_type]['left'], - lang_graph + serialized_expr_chain[expr_chain_type]['left'], lang_graph ) right_link = cls._from_dict( - serialized_expr_chain[expr_chain_type]['right'], - lang_graph + serialized_expr_chain[expr_chain_type]['right'], lang_graph ) - new_expr_chain = ExpressionsChain( - type = expr_chain_type, - left_link = left_link, - right_link = right_link + return ExpressionsChain( + type=expr_chain_type, left_link=left_link, right_link=right_link ) - return new_expr_chain case 'field': - assoc_name = list(serialized_expr_chain.keys())[0] - target_asset = lang_graph.assets[\ - serialized_expr_chain[assoc_name]['asset type']] + assoc_name = next(iter(serialized_expr_chain.keys())) + target_asset = lang_graph.assets[ + serialized_expr_chain[assoc_name]['asset type'] + ] fieldname = serialized_expr_chain[assoc_name]['fieldname'] association = None for assoc in target_asset.associations.values(): - if assoc.contains_fieldname(fieldname) and \ - assoc.name == assoc_name: + if assoc.contains_fieldname(fieldname) and assoc.name == assoc_name: association = assoc break if association is None: - msg = 'Failed to find association "%s" with '\ - 'fieldname "%s"' + msg = 'Failed to find association "%s" with fieldname "%s"' logger.error(msg % (assoc_name, fieldname)) - raise LanguageGraphException(msg % (assoc_name, - fieldname)) + raise LanguageGraphException(msg % (assoc_name, fieldname)) - new_expr_chain = ExpressionsChain( - type = 'field', - association = association, - fieldname = fieldname + return ExpressionsChain( + type='field', association=association, fieldname=fieldname ) - return new_expr_chain case 'transitive': sub_link = cls._from_dict( - serialized_expr_chain['transitive'], - lang_graph - ) - new_expr_chain = ExpressionsChain( - type = 'transitive', - sub_link = sub_link + serialized_expr_chain['transitive'], lang_graph ) - return new_expr_chain + return ExpressionsChain(type='transitive', sub_link=sub_link) case 'subType': sub_link = cls._from_dict( - serialized_expr_chain['expression'], - lang_graph + serialized_expr_chain['expression'], lang_graph ) subtype_name = serialized_expr_chain['subType'] if subtype_name in lang_graph.assets: @@ -585,70 +504,60 @@ def _from_dict(cls, logger.error(msg % subtype_name) raise LanguageGraphException(msg % subtype_name) - new_expr_chain = ExpressionsChain( - type = 'subType', - sub_link = sub_link, - subtype = subtype_asset + return ExpressionsChain( + type='subType', sub_link=sub_link, subtype=subtype_asset ) - return new_expr_chain case _: msg = 'Unknown expressions chain type %s!' logger.error(msg, serialized_expr_chain['type']) - raise LanguageGraphAssociationError(msg % - serialized_expr_chain['type']) - + raise LanguageGraphAssociationError(msg % serialized_expr_chain['type']) def __repr__(self) -> str: return str(self.to_dict()) -class LanguageGraph(): - """Graph representation of a MAL language""" - def __init__(self, lang: Optional[dict] = None): +class LanguageGraph: + """Graph representation of a MAL language.""" + + def __init__(self, lang: dict | None = None) -> None: self.assets: dict = {} if lang is not None: self._lang_spec: dict = lang self.metadata = { - "version": lang["defines"]["version"], - "id": lang["defines"]["id"], + 'version': lang['defines']['version'], + 'id': lang['defines']['id'], } self._generate_graph() - @classmethod def from_mal_spec(cls, mal_spec_file: str) -> LanguageGraph: - """ - Create a LanguageGraph from a .mal file (a MAL spec). + """Create a LanguageGraph from a .mal file (a MAL spec). Arguments: mal_spec_file - the path to the .mal file + """ - logger.info("Loading mal spec %s", mal_spec_file) + logger.info('Loading mal spec %s', mal_spec_file) return LanguageGraph(MalCompiler().compile(mal_spec_file)) - @classmethod def from_mar_archive(cls, mar_archive: str) -> LanguageGraph: - """ - Create a LanguageGraph from a ".mar" archive provided by malc + """Create a LanguageGraph from a ".mar" archive provided by malc (https://github.com/mal-lang/malc). Arguments: mar_archive - the path to a ".mar" archive + """ logger.info('Loading mar archive %s', mar_archive) with zipfile.ZipFile(mar_archive, 'r') as archive: langspec = archive.read('langspec.json') return LanguageGraph(json.loads(langspec)) - def _to_dict(self): - """Converts LanguageGraph into a dict""" - - logger.debug( - 'Serializing %s assets.', len(self.assets.items()) - ) + """Converts LanguageGraph into a dict.""" + logger.debug('Serializing %s assets.', len(self.assets.items())) serialized_graph = {} for asset in self.assets.values(): @@ -656,42 +565,41 @@ def _to_dict(self): return serialized_graph - def _link_association_to_assets(cls, - assoc: LanguageGraphAssociation, - left_asset: LanguageGraphAsset, - right_asset: LanguageGraphAsset): + def _link_association_to_assets( + self, + assoc: LanguageGraphAssociation, + left_asset: LanguageGraphAsset, + right_asset: LanguageGraphAsset, + ) -> None: left_asset.own_associations[assoc.right_field.fieldname] = assoc right_asset.own_associations[assoc.left_field.fieldname] = assoc def save_to_file(self, filename: str) -> None: - """Save to json/yml depending on extension""" + """Save to json/yml depending on extension.""" return save_dict_to_file(filename, self._to_dict()) - @classmethod def _from_dict(cls, serialized_graph: dict) -> LanguageGraph: """Create LanguageGraph from dict Args: - serialized_graph - LanguageGraph in dict format + serialized_graph - LanguageGraph in dict format. """ - logger.debug('Create language graph from dictionary.') lang_graph = LanguageGraph() # Recreate all of the assets for asset_dict in serialized_graph.values(): logger.debug( - 'Create asset language graph nodes for asset %s', - asset_dict['name'] + 'Create asset language graph nodes for asset %s', asset_dict['name'] ) asset_node = LanguageGraphAsset( - name = asset_dict['name'], - own_associations = {}, - attack_steps = {}, - info = asset_dict['info'], - own_super_asset = None, - own_sub_assets = set(), - own_variables = {}, - is_abstract = asset_dict['is_abstract'] + name=asset_dict['name'], + own_associations={}, + attack_steps={}, + info=asset_dict['info'], + own_super_asset=None, + own_sub_assets=set(), + own_variables={}, + is_abstract=asset_dict['is_abstract'], ) lang_graph.assets[asset_dict['name']] = asset_node @@ -705,10 +613,10 @@ def _from_dict(cls, serialized_graph: dict) -> LanguageGraph: super_asset = lang_graph.assets[super_asset_name] if not super_asset: msg = 'Failed to find super asset "%s" for asset "%s"!' - logger.error( - msg, asset_dict["super_asset"], asset_dict["name"]) + logger.error(msg, asset_dict['super_asset'], asset_dict['name']) raise LanguageGraphSuperAssetNotFoundError( - msg % (asset_dict["super_asset"], asset_dict["name"])) + msg % (asset_dict['super_asset'], asset_dict['name']) + ) super_asset.own_sub_assets.add(asset) asset.own_super_asset = super_asset @@ -717,7 +625,7 @@ def _from_dict(cls, serialized_graph: dict) -> LanguageGraph: for asset_dict in serialized_graph.values(): logger.debug( 'Create association language graph nodes for asset %s', - asset_dict['name'] + asset_dict['name'], ) asset = lang_graph.assets[asset_dict['name']] @@ -725,42 +633,42 @@ def _from_dict(cls, serialized_graph: dict) -> LanguageGraph: left_asset = lang_graph.assets[association['left']['asset']] if not left_asset: msg = 'Left asset "%s" for association "%s" not found!' - logger.error( - msg, association['left']['asset'], - association['name']) + logger.error(msg, association['left']['asset'], association['name']) raise LanguageGraphAssociationError( - msg % (association['left']['asset'], - association['name'])) + msg % (association['left']['asset'], association['name']) + ) right_asset = lang_graph.assets[association['right']['asset']] if not right_asset: msg = 'Right asset "%s" for association "%s" not found!' logger.error( - msg, association['right']['asset'], - association['name']) + msg, association['right']['asset'], association['name'] + ) raise LanguageGraphAssociationError( - msg % (association['right']['asset'], - association['name']) + msg % (association['right']['asset'], association['name']) ) assoc_node = LanguageGraphAssociation( - name = association['name'], - left_field = LanguageGraphAssociationField( + name=association['name'], + left_field=LanguageGraphAssociationField( left_asset, association['left']['fieldname'], association['left']['min'], - association['left']['max']), - right_field = LanguageGraphAssociationField( + association['left']['max'], + ), + right_field=LanguageGraphAssociationField( right_asset, association['right']['fieldname'], association['right']['min'], - association['right']['max']), - info = association['info'] + association['right']['max'], + ), + info=association['info'], ) # Add the association to the left and right asset - lang_graph._link_association_to_assets(assoc_node, - left_asset, right_asset) + lang_graph._link_association_to_assets( + assoc_node, left_asset, right_asset + ) # Recreate the variables for asset_dict in serialized_graph.values(): @@ -768,10 +676,7 @@ def _from_dict(cls, serialized_graph: dict) -> LanguageGraph: for variable_name, var_target in asset_dict['variables'].items(): (target_asset_name, expr_chain_dict) = var_target target_asset = lang_graph.assets[target_asset_name] - expr_chain = ExpressionsChain._from_dict( - expr_chain_dict, - lang_graph - ) + expr_chain = ExpressionsChain._from_dict(expr_chain_dict, lang_graph) asset.own_variables[variable_name] = (target_asset, expr_chain) # Recreate the attack steps @@ -779,40 +684,40 @@ def _from_dict(cls, serialized_graph: dict) -> LanguageGraph: asset = lang_graph.assets[asset_dict['name']] logger.debug( 'Create attack steps language graph nodes for asset %s', - asset_dict['name'] + asset_dict['name'], ) for attack_step_dict in asset_dict['attack_steps'].values(): attack_step_node = LanguageGraphAttackStep( - name = attack_step_dict['name'], - type = attack_step_dict['type'], - asset = asset, - ttc = attack_step_dict['ttc'], - overrides = attack_step_dict['overrides'], - children = {}, - parents = {}, - info = attack_step_dict['info'], - tags = set(attack_step_dict['tags']) + name=attack_step_dict['name'], + type=attack_step_dict['type'], + asset=asset, + ttc=attack_step_dict['ttc'], + overrides=attack_step_dict['overrides'], + children={}, + parents={}, + info=attack_step_dict['info'], + tags=set(attack_step_dict['tags']), ) - asset.attack_steps[attack_step_dict['name']] = \ - attack_step_node + asset.attack_steps[attack_step_dict['name']] = attack_step_node # Relink attack steps based on inheritence for asset_dict in serialized_graph.values(): asset = lang_graph.assets[asset_dict['name']] for attack_step_dict in asset_dict['attack_steps'].values(): - if 'inherits' in attack_step_dict and \ - attack_step_dict['inherits'] is not None: - attack_step = asset.attack_steps[ - attack_step_dict['name']] - ancestor_asset_name, ancestor_attack_step_name = \ - disaggregate_attack_step_full_name( - attack_step_dict['inherits']) + if ( + 'inherits' in attack_step_dict + and attack_step_dict['inherits'] is not None + ): + attack_step = asset.attack_steps[attack_step_dict['name']] + ancestor_asset_name, ancestor_attack_step_name = ( + disaggregate_attack_step_full_name(attack_step_dict['inherits']) + ) ancestor_asset = lang_graph.assets[ancestor_asset_name] - ancestor_attack_step = ancestor_asset.attack_steps[\ - ancestor_attack_step_name] + ancestor_attack_step = ancestor_asset.attack_steps[ + ancestor_attack_step_name + ] attack_step.inherits = ancestor_attack_step - # Relink attack steps based on expressions chains for asset_dict in serialized_graph.values(): asset = lang_graph.assets[asset_dict['name']] @@ -821,64 +726,65 @@ def _from_dict(cls, serialized_graph: dict) -> LanguageGraph: for child_target in attack_step_dict['children'].items(): target_full_attack_step_name = child_target[0] expr_chains = child_target[1] - target_asset_name, target_attack_step_name = \ - disaggregate_attack_step_full_name( - target_full_attack_step_name) + target_asset_name, target_attack_step_name = ( + disaggregate_attack_step_full_name(target_full_attack_step_name) + ) target_asset = lang_graph.assets[target_asset_name] target_attack_step = target_asset.attack_steps[ - target_attack_step_name] + target_attack_step_name + ] for expr_chain_dict in expr_chains: expr_chain = ExpressionsChain._from_dict( - expr_chain_dict, - lang_graph + expr_chain_dict, lang_graph ) if target_attack_step.full_name in attack_step.children: - attack_step.children[target_attack_step.full_name].\ - append((target_attack_step, expr_chain)) + attack_step.children[target_attack_step.full_name].append( + (target_attack_step, expr_chain) + ) else: - attack_step.children[target_attack_step.full_name] = \ - [(target_attack_step, expr_chain)] + attack_step.children[target_attack_step.full_name] = [ + (target_attack_step, expr_chain) + ] for parent_target in attack_step_dict['parents'].items(): target_full_attack_step_name = parent_target[0] expr_chains = parent_target[1] - target_asset_name, target_attack_step_name = \ - disaggregate_attack_step_full_name( - target_full_attack_step_name) + target_asset_name, target_attack_step_name = ( + disaggregate_attack_step_full_name(target_full_attack_step_name) + ) target_asset = lang_graph.assets[target_asset_name] target_attack_step = target_asset.attack_steps[ - target_attack_step_name] + target_attack_step_name + ] for expr_chain_dict in expr_chains: expr_chain = ExpressionsChain._from_dict( - expr_chain_dict, - lang_graph + expr_chain_dict, lang_graph ) if target_attack_step.full_name in attack_step.parents: - attack_step.parents[target_attack_step.full_name].\ - append((target_attack_step, expr_chain)) + attack_step.parents[target_attack_step.full_name].append( + (target_attack_step, expr_chain) + ) else: - attack_step.parents[target_attack_step.full_name] = \ - [(target_attack_step, expr_chain)] + attack_step.parents[target_attack_step.full_name] = [ + (target_attack_step, expr_chain) + ] # Recreate the requirements of exist and notExist attack steps - if attack_step.type == 'exist' or \ - attack_step.type == 'notExist': + if attack_step.type in {'exist', 'notExist'}: if 'requires' in attack_step_dict: expr_chains = attack_step_dict['requires'] attack_step.own_requires = [] for expr_chain_dict in expr_chains: expr_chain = ExpressionsChain._from_dict( - expr_chain_dict, - lang_graph + expr_chain_dict, lang_graph ) attack_step.own_requires.append(expr_chain) return lang_graph - @classmethod def load_from_file(cls, filename: str) -> LanguageGraph: - """Create LanguageGraph from mal, mar, yaml or json""" + """Create LanguageGraph from mal, mar, yaml or json.""" lang_graph = None if filename.endswith('.mal'): lang_graph = cls.from_mal_spec(filename) @@ -886,42 +792,33 @@ def load_from_file(cls, filename: str) -> LanguageGraph: lang_graph = cls.from_mar_archive(filename) elif filename.endswith(('.yaml', '.yml')): lang_graph = cls._from_dict(load_dict_from_yaml_file(filename)) - elif filename.endswith(('.json')): + elif filename.endswith('.json'): lang_graph = cls._from_dict(load_dict_from_json_file(filename)) else: - raise TypeError( - "Unknown file extension, expected json/mal/mar/yml/yaml" - ) + msg = 'Unknown file extension, expected json/mal/mar/yml/yaml' + raise TypeError(msg) if lang_graph: return lang_graph - else: - raise LanguageGraphException( - f'Failed to load language graph from file "{filename}".' - ) - - + msg = f'Failed to load language graph from file "{filename}".' + raise LanguageGraphException(msg) def save_language_specification_to_json(self, filename: str) -> None: - """ - Save a MAL language specification dictionary to a JSON file + """Save a MAL language specification dictionary to a JSON file. Arguments: filename - the JSON filename where the language specification will be written + """ logger.info('Save language specification to %s', filename) with open(filename, 'w', encoding='utf-8') as file: json.dump(self._lang_spec, file, indent=4) - - def process_step_expression(self, - target_asset, - expr_chain, - step_expression: dict - ) -> tuple: - """ - Recursively process an attack step expression. + def process_step_expression( + self, target_asset, expr_chain, step_expression: dict + ) -> tuple: + """Recursively process an attack step expression. Arguments: target_asset - The asset type that this step expression should @@ -939,79 +836,61 @@ def process_step_expression(self, Return: A tuple triplet containing the target asset, the resulting parent associations chain, and the name of the attack step. - """ + """ if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug logger.debug( - 'Processing Step Expression:\n%s', - json.dumps(step_expression, indent = 2) + 'Processing Step Expression:\n%s', json.dumps(step_expression, indent=2) ) - match (step_expression['type']): + match step_expression['type']: case 'attackStep': # The attack step expression just adds the name of the attack # step. All other step expressions only modify the target # asset and parent associations chain. - return ( - target_asset, - None, - step_expression['name'] - ) + return (target_asset, None, step_expression['name']) case 'union' | 'intersection' | 'difference': # The set operators are used to combine the left hand and right # hand targets accordingly. lh_target_asset, lh_expr_chain, _ = self.process_step_expression( - target_asset, - expr_chain, - step_expression['lhs'] + target_asset, expr_chain, step_expression['lhs'] + ) + rh_target_asset, rh_expr_chain, _ = self.process_step_expression( + target_asset, expr_chain, step_expression['rhs'] ) - rh_target_asset, rh_expr_chain, _ = \ - self.process_step_expression( - target_asset, - expr_chain, - step_expression['rhs'] - ) - if not lh_target_asset.get_all_common_superassets( - rh_target_asset): + if not lh_target_asset.get_all_common_superassets(rh_target_asset): logger.error( - "Set operation attempted between targets that" - " do not share any common superassets: %s and %s!", - lh_target_asset.name, rh_target_asset.name + 'Set operation attempted between targets that' + ' do not share any common superassets: %s and %s!', + lh_target_asset.name, + rh_target_asset.name, ) return (None, None, None) new_expr_chain = ExpressionsChain( - type = step_expression['type'], - left_link = lh_expr_chain, - right_link = rh_expr_chain - ) - return ( - lh_target_asset, - new_expr_chain, - None + type=step_expression['type'], + left_link=lh_expr_chain, + right_link=rh_expr_chain, ) + return (lh_target_asset, new_expr_chain, None) case 'variable': var_name = step_expression['name'] var_target_asset, var_expr_chain = self._resolve_variable( - target_asset, var_name) - var_target_asset, var_expr_chain = \ - target_asset.get_variable(var_name) + target_asset, var_name + ) + var_target_asset, var_expr_chain = target_asset.get_variable(var_name) if var_expr_chain is not None: - return ( - var_target_asset, - var_expr_chain, - None - ) - else: - logger.error( - 'Failed to find variable \"%s\" for %s', - step_expression["name"], target_asset.name - ) - return (None, None, None) + return (var_target_asset, var_expr_chain, None) + logger.error( + 'Failed to find variable "%s" for %s', + step_expression['name'], + target_asset.name, + ) + return (None, None, None) case 'field': # Change the target asset from the current one to the associated @@ -1019,74 +898,58 @@ def process_step_expression(self, # fieldname and association to the parent associations chain. fieldname = step_expression['name'] if not target_asset: - logger.error( - 'Missing target asset for field "%s"!', fieldname - ) + logger.error('Missing target asset for field "%s"!', fieldname) return (None, None, None) new_target_asset = None for association in target_asset.associations.values(): - if (association.left_field.fieldname == fieldname and \ - target_asset.is_subasset_of( - association.right_field.asset)): + if ( + association.left_field.fieldname == fieldname + and target_asset.is_subasset_of(association.right_field.asset) + ): new_target_asset = association.left_field.asset - if (association.right_field.fieldname == fieldname and \ - target_asset.is_subasset_of( - association.left_field.asset)): + if ( + association.right_field.fieldname == fieldname + and target_asset.is_subasset_of(association.left_field.asset) + ): new_target_asset = association.right_field.asset if new_target_asset: new_expr_chain = ExpressionsChain( - type = 'field', - fieldname = fieldname, - association = association - ) - return ( - new_target_asset, - new_expr_chain, - None + type='field', fieldname=fieldname, association=association ) + return (new_target_asset, new_expr_chain, None) logger.error( 'Failed to find field "%s" on asset "%s"!', - fieldname, target_asset.name + fieldname, + target_asset.name, ) return (None, None, None) case 'transitive': # Create a transitive tuple entry that applies to the next # component of the step expression. - result_target_asset, \ - result_expr_chain, \ - attack_step = \ + result_target_asset, result_expr_chain, attack_step = ( self.process_step_expression( - target_asset, - expr_chain, - step_expression['stepExpression'] + target_asset, expr_chain, step_expression['stepExpression'] ) - new_expr_chain = ExpressionsChain( - type = 'transitive', - sub_link = result_expr_chain ) - return ( - result_target_asset, - new_expr_chain, - attack_step + new_expr_chain = ExpressionsChain( + type='transitive', sub_link=result_expr_chain ) + return (result_target_asset, new_expr_chain, attack_step) case 'subType': # Create a subType tuple entry that applies to the next # component of the step expression and changes the target # asset to the subasset. subtype_name = step_expression['subType'] - result_target_asset, \ - result_expr_chain, \ - attack_step = \ + result_target_asset, result_expr_chain, attack_step = ( self.process_step_expression( - target_asset, - expr_chain, - step_expression['stepExpression'] + target_asset, expr_chain, step_expression['stepExpression'] ) + ) if subtype_name in self.assets: subtype_asset = self.assets[subtype_name] @@ -1099,67 +962,48 @@ def process_step_expression(self, logger.error( 'Found subtype "%s" which does not extend "%s", ' 'therefore the subtype cannot be resolved.', - subtype_name, result_target_asset.name + subtype_name, + result_target_asset.name, ) return (None, None, None) new_expr_chain = ExpressionsChain( - type = 'subType', - sub_link = result_expr_chain, - subtype = subtype_asset - ) - return ( - subtype_asset, - new_expr_chain, - attack_step + type='subType', sub_link=result_expr_chain, subtype=subtype_asset ) + return (subtype_asset, new_expr_chain, attack_step) case 'collect': # Apply the right hand step expression to left hand step # expression target asset and parent associations chain. - (lh_target_asset, lh_expr_chain, _) = \ - self.process_step_expression( - target_asset, - expr_chain, - step_expression['lhs'] - ) - (rh_target_asset, - rh_expr_chain, - rh_attack_step_name) = \ + (lh_target_asset, lh_expr_chain, _) = self.process_step_expression( + target_asset, expr_chain, step_expression['lhs'] + ) + (rh_target_asset, rh_expr_chain, rh_attack_step_name) = ( self.process_step_expression( - lh_target_asset, - None, - step_expression['rhs'] + lh_target_asset, None, step_expression['rhs'] ) + ) if rh_expr_chain: new_expr_chain = ExpressionsChain( - type = 'collect', - left_link = lh_expr_chain, - right_link = rh_expr_chain + type='collect', + left_link=lh_expr_chain, + right_link=rh_expr_chain, ) else: new_expr_chain = lh_expr_chain - return ( - rh_target_asset, - new_expr_chain, - rh_attack_step_name - ) + return (rh_target_asset, new_expr_chain, rh_attack_step_name) case _: - logger.error( - 'Unknown attack step type: "%s"', step_expression["type"] - ) + logger.error('Unknown attack step type: "%s"', step_expression['type']) return (None, None, None) - def reverse_expr_chain( - self, - expr_chain: Optional[ExpressionsChain], - reverse_chain: Optional[ExpressionsChain] - ) -> Optional[ExpressionsChain]: - """ - Recursively reverse the associations chain. From parent to child or + self, + expr_chain: ExpressionsChain | None, + reverse_chain: ExpressionsChain | None, + ) -> ExpressionsChain | None: + """Recursively reverse the associations chain. From parent to child or vice versa. Arguments: @@ -1171,84 +1015,76 @@ def reverse_expr_chain( Return: The resulting reversed associations chain. + """ if not expr_chain: return reverse_chain - else: - match (expr_chain.type): - case 'union' | 'intersection' | 'difference' | 'collect': - left_reverse_chain = \ - self.reverse_expr_chain(expr_chain.left_link, - reverse_chain) - right_reverse_chain = \ - self.reverse_expr_chain(expr_chain.right_link, - reverse_chain) - if expr_chain.type == 'collect': - new_expr_chain = ExpressionsChain( - type = expr_chain.type, - left_link = right_reverse_chain, - right_link = left_reverse_chain - ) - else: - new_expr_chain = ExpressionsChain( - type = expr_chain.type, - left_link = left_reverse_chain, - right_link = right_reverse_chain - ) - - return new_expr_chain - - case 'transitive': - result_reverse_chain = self.reverse_expr_chain( - expr_chain.sub_link, reverse_chain) + match expr_chain.type: + case 'union' | 'intersection' | 'difference' | 'collect': + left_reverse_chain = self.reverse_expr_chain( + expr_chain.left_link, reverse_chain + ) + right_reverse_chain = self.reverse_expr_chain( + expr_chain.right_link, reverse_chain + ) + if expr_chain.type == 'collect': + new_expr_chain = ExpressionsChain( + type=expr_chain.type, + left_link=right_reverse_chain, + right_link=left_reverse_chain, + ) + else: new_expr_chain = ExpressionsChain( - type = 'transitive', - sub_link = result_reverse_chain + type=expr_chain.type, + left_link=left_reverse_chain, + right_link=right_reverse_chain, ) - return new_expr_chain - case 'field': - association = expr_chain.association + return new_expr_chain - if not association: - raise LanguageGraphException( - "Missing association for expressions chain" - ) + case 'transitive': + result_reverse_chain = self.reverse_expr_chain( + expr_chain.sub_link, reverse_chain + ) + return ExpressionsChain( + type='transitive', sub_link=result_reverse_chain + ) - if not expr_chain.fieldname: - raise LanguageGraphException( - "Missing field name for expressions chain" - ) + case 'field': + association = expr_chain.association - opposite_fieldname = association.get_opposite_fieldname( - expr_chain.fieldname) - new_expr_chain = ExpressionsChain( - type = 'field', - association = association, - fieldname = opposite_fieldname - ) - return new_expr_chain + if not association: + msg = 'Missing association for expressions chain' + raise LanguageGraphException(msg) - case 'subType': - result_reverse_chain = self.reverse_expr_chain( - expr_chain.sub_link, - reverse_chain - ) - new_expr_chain = ExpressionsChain( - type = 'subType', - sub_link = result_reverse_chain, - subtype = expr_chain.subtype - ) - return new_expr_chain + if not expr_chain.fieldname: + msg = 'Missing field name for expressions chain' + raise LanguageGraphException(msg) - case _: - msg = 'Unknown assoc chain element "%s"' - logger.error(msg, expr_chain.type) - raise LanguageGraphAssociationError(msg % expr_chain.type) + opposite_fieldname = association.get_opposite_fieldname( + expr_chain.fieldname + ) + return ExpressionsChain( + type='field', association=association, fieldname=opposite_fieldname + ) + + case 'subType': + result_reverse_chain = self.reverse_expr_chain( + expr_chain.sub_link, reverse_chain + ) + return ExpressionsChain( + type='subType', + sub_link=result_reverse_chain, + subtype=expr_chain.subtype, + ) + + case _: + msg = 'Unknown assoc chain element "%s"' + logger.error(msg, expr_chain.type) + raise LanguageGraphAssociationError(msg % expr_chain.type) def _resolve_variable(self, asset, var_name) -> tuple: - """ - Resolve a variable for a specific asset by variable name. + """Resolve a variable for a specific asset by variable name. Arguments: asset - a language graph asset to which the variable belongs @@ -1257,39 +1093,35 @@ def _resolve_variable(self, asset, var_name) -> tuple: Return: A tuple containing the target asset and expressions chain required to reach it. + """ if var_name not in asset.variables: var_expr = self._get_var_expr_for_asset(asset.name, var_name) target_asset, expr_chain, _ = self.process_step_expression( - asset, - None, - var_expr + asset, None, var_expr ) asset.own_variables[var_name] = (target_asset, expr_chain) return (target_asset, expr_chain) return asset.variables[var_name] - def _generate_graph(self) -> None: - """ - Generate language graph starting from the MAL language specification + """Generate language graph starting from the MAL language specification given in the constructor. """ # Generate all of the asset nodes of the language graph. for asset_dict in self._lang_spec['assets']: logger.debug( - 'Create asset language graph nodes for asset %s', - asset_dict['name'] + 'Create asset language graph nodes for asset %s', asset_dict['name'] ) asset_node = LanguageGraphAsset( - name = asset_dict['name'], - own_associations = {}, - attack_steps = {}, - info = asset_dict['meta'], - own_super_asset = None, - own_sub_assets = set(), - own_variables = {}, - is_abstract = asset_dict['isAbstract'] + name=asset_dict['name'], + own_associations={}, + attack_steps={}, + info=asset_dict['meta'], + own_super_asset=None, + own_sub_assets=set(), + own_variables={}, + is_abstract=asset_dict['isAbstract'], ) self.assets[asset_dict['name']] = asset_node @@ -1300,10 +1132,10 @@ def _generate_graph(self) -> None: super_asset = self.assets[asset_dict['superAsset']] if not super_asset: msg = 'Failed to find super asset "%s" for asset "%s"!' - logger.error( - msg, asset_dict["superAsset"], asset_dict["name"]) + logger.error(msg, asset_dict['superAsset'], asset_dict['name']) raise LanguageGraphSuperAssetNotFoundError( - msg % (asset_dict["superAsset"], asset_dict["name"])) + msg % (asset_dict['superAsset'], asset_dict['name']) + ) super_asset.own_sub_assets.add(asset) asset.own_super_asset = super_asset @@ -1311,8 +1143,7 @@ def _generate_graph(self) -> None: # Generate all of the association nodes of the language graph. for asset in self.assets.values(): logger.debug( - 'Create association language graph nodes for asset %s', - asset.name + 'Create association language graph nodes for asset %s', asset.name ) associations = self._get_associations_for_asset_type(asset.name) @@ -1320,38 +1151,38 @@ def _generate_graph(self) -> None: left_asset = self.assets[association['leftAsset']] if not left_asset: msg = 'Left asset "%s" for association "%s" not found!' - logger.error( - msg, association["leftAsset"], association["name"]) + logger.error(msg, association['leftAsset'], association['name']) raise LanguageGraphAssociationError( - msg % (association["leftAsset"], association["name"])) + msg % (association['leftAsset'], association['name']) + ) right_asset = self.assets[association['rightAsset']] if not right_asset: msg = 'Right asset "%s" for association "%s" not found!' - logger.error( - msg, association["rightAsset"], association["name"]) + logger.error(msg, association['rightAsset'], association['name']) raise LanguageGraphAssociationError( - msg % (association["rightAsset"], association["name"]) + msg % (association['rightAsset'], association['name']) ) assoc_node = LanguageGraphAssociation( - name = association['name'], - left_field = LanguageGraphAssociationField( + name=association['name'], + left_field=LanguageGraphAssociationField( left_asset, association['leftField'], association['leftMultiplicity']['min'], - association['leftMultiplicity']['max']), - right_field = LanguageGraphAssociationField( + association['leftMultiplicity']['max'], + ), + right_field=LanguageGraphAssociationField( right_asset, association['rightField'], association['rightMultiplicity']['min'], - association['rightMultiplicity']['max']), - info = association['meta'] + association['rightMultiplicity']['max'], + ), + info=association['meta'], ) # Add the association to the left and right asset - self._link_association_to_assets(assoc_node, - left_asset, right_asset) + self._link_association_to_assets(assoc_node, left_asset, right_asset) # Set the variables for asset in self.assets.values(): @@ -1360,39 +1191,37 @@ def _generate_graph(self) -> None: # Avoid running json.dumps when not in debug logger.debug( 'Processing Variable Expression:\n%s', - json.dumps(variable, indent = 2) + json.dumps(variable, indent=2), ) self._resolve_variable(asset, variable['name']) - # Generate all of the attack step nodes of the language graph. for asset in self.assets.values(): logger.debug( - 'Create attack steps language graph nodes for asset %s', - asset.name + 'Create attack steps language graph nodes for asset %s', asset.name ) attack_steps = self._get_attacks_for_asset_type(asset.name) for attack_step_attribs in attack_steps.values(): logger.debug( 'Create attack step language graph nodes for %s', - attack_step_attribs['name'] + attack_step_attribs['name'], ) attack_step_node = LanguageGraphAttackStep( - name = attack_step_attribs['name'], - type = attack_step_attribs['type'], - asset = asset, - ttc = attack_step_attribs['ttc'], - overrides = attack_step_attribs['reaches']['overrides'] \ - if attack_step_attribs['reaches'] else False, - children = {}, - parents = {}, - info = attack_step_attribs['meta'], - tags = set(attack_step_attribs['tags']) + name=attack_step_attribs['name'], + type=attack_step_attribs['type'], + asset=asset, + ttc=attack_step_attribs['ttc'], + overrides=attack_step_attribs['reaches']['overrides'] + if attack_step_attribs['reaches'] + else False, + children={}, + parents={}, + info=attack_step_attribs['meta'], + tags=set(attack_step_attribs['tags']), ) attack_step_node._attributes = attack_step_attribs - asset.attack_steps[attack_step_attribs['name']] = \ - attack_step_node + asset.attack_steps[attack_step_attribs['name']] = attack_step_node # Create the inherited attack steps assets = list(self.assets.values()) @@ -1402,137 +1231,136 @@ def _generate_graph(self) -> None: # The asset still has super assets that should be resolved # first, moved it to the back. assets.append(asset) - else: - if asset.own_super_asset: - for attack_step in \ - asset.own_super_asset.attack_steps.values(): - if attack_step.name not in asset.attack_steps: - attack_step_node = LanguageGraphAttackStep( - name = attack_step.name, - type = attack_step.type, - asset = asset, - ttc = attack_step.ttc, - overrides = False, - children = {}, - parents = {}, - info = attack_step.info, - tags = set(attack_step.tags) - ) - attack_step_node.inherits = attack_step - asset.attack_steps[attack_step.name] = attack_step_node - elif asset.attack_steps[attack_step.name].overrides: - # The inherited attack step was already overridden. - continue - else: - asset.attack_steps[attack_step.name].inherits = \ - attack_step - asset.attack_steps[attack_step.name].tags |= \ - attack_step.tags - asset.attack_steps[attack_step.name].info |= \ - attack_step.info + elif asset.own_super_asset: + for attack_step in asset.own_super_asset.attack_steps.values(): + if attack_step.name not in asset.attack_steps: + attack_step_node = LanguageGraphAttackStep( + name=attack_step.name, + type=attack_step.type, + asset=asset, + ttc=attack_step.ttc, + overrides=False, + children={}, + parents={}, + info=attack_step.info, + tags=set(attack_step.tags), + ) + attack_step_node.inherits = attack_step + asset.attack_steps[attack_step.name] = attack_step_node + elif asset.attack_steps[attack_step.name].overrides: + # The inherited attack step was already overridden. + continue + else: + asset.attack_steps[attack_step.name].inherits = attack_step + asset.attack_steps[attack_step.name].tags |= attack_step.tags + asset.attack_steps[attack_step.name].info |= attack_step.info # Then, link all of the attack step nodes according to their # associations. for asset in self.assets.values(): for attack_step in asset.attack_steps.values(): logger.debug( - 'Determining children for attack step %s', - attack_step.name + 'Determining children for attack step %s', attack_step.name ) if attack_step._attributes is None: # This is simply an empty inherited attack step continue - step_expressions = \ - attack_step._attributes['reaches']['stepExpressions'] if \ - attack_step._attributes['reaches'] else [] + step_expressions = ( + attack_step._attributes['reaches']['stepExpressions'] + if attack_step._attributes['reaches'] + else [] + ) for step_expression in step_expressions: # Resolve each of the attack step expressions listed for # this attack step to determine children. - (target_asset, expr_chain, target_attack_step_name) = \ + (target_asset, expr_chain, target_attack_step_name) = ( self.process_step_expression( - attack_step.asset, - None, - step_expression + attack_step.asset, None, step_expression ) + ) if not target_asset: - msg = 'Failed to find target asset to link with for ' \ + msg = ( + 'Failed to find target asset to link with for ' 'step expression:\n%s' + ) raise LanguageGraphStepExpressionError( - msg % json.dumps(step_expression, indent = 2) + msg % json.dumps(step_expression, indent=2) ) target_asset_attack_steps = target_asset.attack_steps - if target_attack_step_name not in \ - target_asset_attack_steps: - msg = 'Failed to find target attack step %s on %s to ' \ - 'link with for step expression:\n%s' + if target_attack_step_name not in target_asset_attack_steps: + msg = ( + 'Failed to find target attack step %s on %s to ' + 'link with for step expression:\n%s' + ) raise LanguageGraphStepExpressionError( - msg % ( + msg + % ( target_attack_step_name, target_asset.name, - json.dumps(step_expression, indent = 2) + json.dumps(step_expression, indent=2), ) ) target_attack_step = target_asset_attack_steps[ - target_attack_step_name] + target_attack_step_name + ] # Link to the children target attack steps if target_attack_step.full_name in attack_step.children: - attack_step.children[target_attack_step.full_name].\ - append((target_attack_step, expr_chain)) + attack_step.children[target_attack_step.full_name].append( + (target_attack_step, expr_chain) + ) else: - attack_step.children[target_attack_step.full_name] = \ - [(target_attack_step, expr_chain)] + attack_step.children[target_attack_step.full_name] = [ + (target_attack_step, expr_chain) + ] # Reverse the children associations chains to get the # parents associations chain. if attack_step.full_name in target_attack_step.parents: - target_attack_step.parents[attack_step.full_name].\ - append((attack_step, - self.reverse_expr_chain(expr_chain, - None))) + target_attack_step.parents[attack_step.full_name].append( + (attack_step, self.reverse_expr_chain(expr_chain, None)) + ) else: - target_attack_step.parents[attack_step.full_name] = \ - [(attack_step, - self.reverse_expr_chain(expr_chain, - None))] + target_attack_step.parents[attack_step.full_name] = [ + (attack_step, self.reverse_expr_chain(expr_chain, None)) + ] # Evaluate the requirements of exist and notExist attack steps - if attack_step.type == 'exist' or \ - attack_step.type == 'notExist': - step_expressions = \ - attack_step._attributes['requires']['stepExpressions'] \ - if attack_step._attributes['requires'] else [] + if attack_step.type in {'exist', 'notExist'}: + step_expressions = ( + attack_step._attributes['requires']['stepExpressions'] + if attack_step._attributes['requires'] + else [] + ) if not step_expressions: - msg = 'Failed to find requirements for attack step' \ - ' "%s" of type "%s":\n%s' + msg = ( + 'Failed to find requirements for attack step' + ' "%s" of type "%s":\n%s' + ) raise LanguageGraphStepExpressionError( - msg % ( + msg + % ( attack_step.name, attack_step.type, - json.dumps(attack_step._attributes, indent = 2) + json.dumps(attack_step._attributes, indent=2), ) ) attack_step.own_requires = [] for step_expression in step_expressions: - _, \ - result_expr_chain, \ - _ = \ - self.process_step_expression( - attack_step.asset, - None, - step_expression - ) + _, result_expr_chain, _ = self.process_step_expression( + attack_step.asset, None, step_expression + ) attack_step.own_requires.append( - self.reverse_expr_chain(result_expr_chain, None)) + self.reverse_expr_chain(result_expr_chain, None) + ) def _get_attacks_for_asset_type(self, asset_type: str) -> dict: - """ - Get all Attack Steps for a specific Class + """Get all Attack Steps for a specific Class. Arguments: asset_type - a string representing the class for which we want to @@ -1544,32 +1372,29 @@ def _get_attacks_for_asset_type(self, asset_type: str) -> dict: associated with a dictionary containing other characteristics of the attack such as type of attack, TTC distribution, child attack steps and other information + """ attack_steps: dict = {} try: asset = next( - asset for asset in self._lang_spec['assets'] \ - if asset['name'] == asset_type + asset + for asset in self._lang_spec['assets'] + if asset['name'] == asset_type ) except StopIteration: - logger.error( - 'Failed to find asset type %s when looking' - 'for attack steps.', asset_type + logger.exception( + 'Failed to find asset type %s when lookingfor attack steps.', asset_type ) return attack_steps logger.debug( - 'Get attack steps for %s asset from ' - 'language specification.', asset['name'] + 'Get attack steps for %s asset from language specification.', asset['name'] ) - attack_steps = {step['name']: step for step in asset['attackSteps']} - - return attack_steps + return {step['name']: step for step in asset['attackSteps']} def _get_associations_for_asset_type(self, asset_type: str) -> list: - """ - Get all Associations for a specific Class + """Get all Associations for a specific Class. Arguments: asset_type - a string representing the class for which we want to @@ -1581,37 +1406,43 @@ def _get_associations_for_asset_type(self, asset_type: str) -> list: with a dictionary containing other characteristics of the attack such as type of attack, TTC distribution, child attack steps and other information + """ logger.debug( - 'Get associations for %s asset from ' - 'language specification.', asset_type + 'Get associations for %s asset from language specification.', asset_type ) associations: list = [] - asset = next((asset for asset in self._lang_spec['assets'] \ - if asset['name'] == asset_type), None) + asset = next( + ( + asset + for asset in self._lang_spec['assets'] + if asset['name'] == asset_type + ), + None, + ) if not asset: logger.error( - 'Failed to find asset type %s when ' - 'looking for associations.', asset_type + 'Failed to find asset type %s when looking for associations.', + asset_type, ) return associations - assoc_iter = (assoc for assoc in self._lang_spec['associations'] \ - if assoc['leftAsset'] == asset_type or \ - assoc['rightAsset'] == asset_type) + assoc_iter = ( + assoc + for assoc in self._lang_spec['associations'] + if assoc['leftAsset'] == asset_type or assoc['rightAsset'] == asset_type + ) assoc = next(assoc_iter, None) - while (assoc): + while assoc: associations.append(assoc) assoc = next(assoc_iter, None) return associations - def _get_variables_for_asset_type( - self, asset_type: str) -> dict: - """ - Get a variables for a specific asset type by name. - Note: Variables are the ones specified in MAL through `let` statements + def _get_variables_for_asset_type(self, asset_type: str) -> dict: + """Get a variables for a specific asset type by name. + Note: Variables are the ones specified in MAL through `let` statements. Arguments: asset_type - a string representing the type of asset which @@ -1620,22 +1451,28 @@ def _get_variables_for_asset_type( Return: A dictionary representing the step expressions for the variables belonging to the asset. - """ - asset_dict = next((asset for asset in self._lang_spec['assets'] \ - if asset['name'] == asset_type), None) + """ + asset_dict = next( + ( + asset + for asset in self._lang_spec['assets'] + if asset['name'] == asset_type + ), + None, + ) if not asset_dict: - msg = 'Failed to find asset type %s in language specification '\ + msg = ( + 'Failed to find asset type %s in language specification ' 'when looking for variables.' + ) logger.error(msg, asset_type) raise LanguageGraphException(msg % asset_type) return asset_dict['variables'] - def _get_var_expr_for_asset( - self, asset_type: str, var_name) -> dict: - """ - Get a variable for a specific asset type by variable name. + def _get_var_expr_for_asset(self, asset_type: str, var_name) -> dict: + """Get a variable for a specific asset type by variable name. Arguments: asset_type - a string representing the type of asset which @@ -1644,27 +1481,31 @@ def _get_var_expr_for_asset( Return: A dictionary representing the step expression for the variable. - """ + """ vars_dict = self._get_variables_for_asset_type(asset_type) - var_expr = next((var_entry['stepExpression'] for var_entry \ - in vars_dict if var_entry['name'] == var_name), None) + var_expr = next( + ( + var_entry['stepExpression'] + for var_entry in vars_dict + if var_entry['name'] == var_name + ), + None, + ) if not var_expr: - msg = 'Failed to find variable name "%s" in language '\ + msg = ( + 'Failed to find variable name "%s" in language ' 'specification when looking for variables for "%s" asset.' + ) logger.error(msg, var_name, asset_type) raise LanguageGraphException(msg % (var_name, asset_type)) return var_expr def regenerate_graph(self) -> None: - """ - Regenerate language graph starting from the MAL language specification + """Regenerate language graph starting from the MAL language specification given in the constructor. """ - self.assets = {} self._generate_graph() - - diff --git a/maltoolbox/model.py b/maltoolbox/model.py index d7bed069..67a4aeb1 100644 --- a/maltoolbox/model.py +++ b/maltoolbox/model.py @@ -1,44 +1,45 @@ -""" -MAL-Toolbox Model Module -""" +"""MAL-Toolbox Model Module.""" from __future__ import annotations -from dataclasses import dataclass, field + import json import logging +from dataclasses import dataclass, field from typing import TYPE_CHECKING +from . import __version__ +from .exceptions import DuplicateModelAssociationError, ModelAssociationException from .file_utils import ( load_dict_from_json_file, load_dict_from_yaml_file, - save_dict_to_file + save_dict_to_file, ) -from . import __version__ -from .exceptions import DuplicateModelAssociationError, ModelAssociationException - if TYPE_CHECKING: - from typing import Any, Optional, TypeAlias - from .language import LanguageClassesFactory + from typing import Any, TypeAlias + from python_jsonschema_objects.classbuilder import ProtocolBase + from .language import LanguageClassesFactory + SchemaGeneratedClass: TypeAlias = ProtocolBase logger = logging.getLogger(__name__) + @dataclass class AttackerAttachment: - """Used to attach attackers to attack step entry points of assets""" - id: Optional[int] = None - name: Optional[str] = None - entry_points: list[tuple[SchemaGeneratedClass, list[str]]] = \ - field(default_factory=lambda: []) + """Used to attach attackers to attack step entry points of assets.""" + id: int | None = None + name: str | None = None + entry_points: list[tuple[SchemaGeneratedClass, list[str]]] = field( + default_factory=list + ) def get_entry_point_tuple( - self, - asset: SchemaGeneratedClass - ) -> Optional[tuple[SchemaGeneratedClass, list[str]]]: + self, asset: SchemaGeneratedClass + ) -> tuple[SchemaGeneratedClass, list[str]] | None: """Return an entry point tuple of an AttackerAttachment matching the asset provided. @@ -51,14 +52,16 @@ def get_entry_point_tuple( steps if the asset has any entry points defined for this attacker attachemnt. None, otherwise. - """ - return next((ep_tuple for ep_tuple in self.entry_points - if ep_tuple[0] == asset), None) + """ + return next( + (ep_tuple for ep_tuple in self.entry_points if ep_tuple[0] == asset), None + ) def add_entry_point( - self, asset: SchemaGeneratedClass, attackstep_name: str): - """Add an entry point to an AttackerAttachment + self, asset: SchemaGeneratedClass, attackstep_name: str + ) -> None: + """Add an entry point to an AttackerAttachment. self.entry_points contain tuples, first element of each tuple is an asset, second element is a list of attack step names that @@ -67,8 +70,8 @@ def add_entry_point( Arguments: asset - the asset to add the entry point to attackstep_name - the name of the attack step to add as an entry point - """ + """ logger.debug( f'Add entry point "{attackstep_name}" on asset "{asset.name}" ' f'to AttackerAttachment "{self.name}".' @@ -93,13 +96,14 @@ def add_entry_point( self.entry_points.append((asset, [attackstep_name])) def remove_entry_point( - self, asset: SchemaGeneratedClass, attackstep_name: str): - """Remove an entry point from an AttackerAttachment if it exists + self, asset: SchemaGeneratedClass, attackstep_name: str + ) -> None: + """Remove an entry point from an AttackerAttachment if it exists. Arguments: asset - the asset to remove the entry point from - """ + """ logger.debug( f'Remove entry point "{attackstep_name}" on asset "{asset.name}" ' f'from AttackerAttachment "{self.name}".' @@ -128,24 +132,24 @@ def remove_entry_point( ) -class Model(): - """An implementation of a MAL language with assets and associations""" +class Model: + """An implementation of a MAL language with assets and associations.""" + next_id: int = 0 def __repr__(self) -> str: return f'Model {self.name}' def __init__( - self, - name: str, - lang_classes_factory: LanguageClassesFactory, - mt_version: str = __version__ - ): - + self, + name: str, + lang_classes_factory: LanguageClassesFactory, + mt_version: str = __version__, + ) -> None: self.name = name self.assets: list[SchemaGeneratedClass] = [] self.associations: list[SchemaGeneratedClass] = [] - self._type_to_association:dict = {} # optimization + self._type_to_association: dict = {} # optimization self.attackers: list[AttackerAttachment] = [] self.lang_classes_factory: LanguageClassesFactory = lang_classes_factory self.maltoolbox_version: str = mt_version @@ -156,11 +160,11 @@ def __init__( self.asset_names: set[str] = set() def add_asset( - self, - asset: SchemaGeneratedClass, - asset_id: Optional[int] = None, - allow_duplicate_names: bool = True - ) -> None: + self, + asset: SchemaGeneratedClass, + asset_id: int | None = None, + allow_duplicate_names: bool = True, + ) -> None: """Add an asset to the model. Arguments: @@ -173,12 +177,13 @@ def add_asset( Return: An asset matching the name if it exists in the model. - """ + """ # Set asset ID and check for duplicates asset.id = asset_id or self.next_id if asset.id in self.asset_ids: - raise ValueError(f'Asset index {asset_id} already in use.') + msg = f'Asset index {asset_id} already in use.' + raise ValueError(msg) self.asset_ids.add(asset.id) self.next_id = max(asset.id + 1, self.next_id) @@ -187,28 +192,26 @@ def add_asset( if not hasattr(asset, 'name'): asset.name = asset.type + ':' + str(asset.id) - else: - if asset.name in self.asset_names: - if allow_duplicate_names: - asset.name = asset.name + ':' + str(asset.id) - else: - raise ValueError( - f'Asset name {asset.name} is a duplicate' - ' and we do not allow duplicates.' - ) + elif asset.name in self.asset_names: + if allow_duplicate_names: + asset.name = asset.name + ':' + str(asset.id) + else: + msg = ( + f'Asset name {asset.name} is a duplicate' + ' and we do not allow duplicates.' + ) + raise ValueError(msg) self.asset_names.add(asset.name) # Optional field for extra asset data if not hasattr(asset, 'extras'): asset.extras = {} - logger.debug( - 'Add "%s"(%d) to model "%s".', asset.name, asset.id, self.name - ) + logger.debug('Add "%s"(%d) to model "%s".', asset.name, asset.id, self.name) self.assets.append(asset) def remove_attacker(self, attacker: AttackerAttachment) -> None: - """Remove attacker""" + """Remove attacker.""" self.attackers.remove(attacker) def remove_asset(self, asset: SchemaGeneratedClass) -> None: @@ -216,17 +219,14 @@ def remove_asset(self, asset: SchemaGeneratedClass) -> None: Arguments: asset - the asset to remove - """ + """ logger.debug( - 'Remove "%s"(%d) from model "%s".', - asset.name, asset.id, self.name + 'Remove "%s"(%d) from model "%s".', asset.name, asset.id, self.name ) if asset not in self.assets: - raise LookupError( - f'Asset "{asset.name}"({asset.id}) is not part' - f' of model"{self.name}".' - ) + msg = f'Asset "{asset.name}"({asset.id}) is not part of model"{self.name}".' + raise LookupError(msg) # First remove all of the associations for association in asset.associations: @@ -241,35 +241,35 @@ def remove_asset(self, asset: SchemaGeneratedClass) -> None: self.assets.remove(asset) def remove_asset_from_association( - self, - asset: SchemaGeneratedClass, - association: SchemaGeneratedClass - ) -> None: + self, asset: SchemaGeneratedClass, association: SchemaGeneratedClass + ) -> None: """Remove an asset from an association and remove the association if any of the two sides is now empty. Arguments: asset - the asset to remove from the given association association - the association to remove the asset from - """ + """ logger.debug( 'Remove "%s"(%d) from association of type "%s".', - asset.name, asset.id, type(association) + asset.name, + asset.id, + type(association), ) if asset not in self.assets: - raise LookupError( - f'Asset "{asset.name}"({asset.id}) is not part of model ' - f'"{self.name}".' + msg = ( + f'Asset "{asset.name}"({asset.id}) is not part of model "{self.name}".' ) + raise LookupError(msg) if association not in self.associations: - raise LookupError( - f'Association is not part of model "{self.name}".' - ) + msg = f'Association is not part of model "{self.name}".' + raise LookupError(msg) - left_field_name, right_field_name = \ - self.get_association_field_names(association) + left_field_name, right_field_name = self.get_association_field_names( + association + ) left_field = getattr(association, left_field_name) right_field = getattr(association, right_field_name) found = False @@ -284,8 +284,11 @@ def remove_asset_from_association( field.remove(asset) if not found: - raise LookupError(f'Asset "{asset.name}"({asset.id}) is not ' - 'part of the association provided.') + msg = ( + f'Asset "{asset.name}"({asset.id}) is not ' + 'part of the association provided.' + ) + raise LookupError(msg) def _validate_association(self, association: SchemaGeneratedClass) -> None: """Raise error if association is invalid or already part of the Model. @@ -293,50 +296,48 @@ def _validate_association(self, association: SchemaGeneratedClass) -> None: Raises: DuplicateAssociationError - same association already exists ModelAssociationException - association is not valid - """ + """ # Optimization: only look for duplicates in associations of same type association_type = association.type - associations_same_type = self._type_to_association.get( - association_type, [] - ) + associations_same_type = self._type_to_association.get(association_type, []) # Check if identical association already exists if association in associations_same_type: - raise DuplicateModelAssociationError( - f"Identical association {association_type} already exists" - ) - + msg = f'Identical association {association_type} already exists' + raise DuplicateModelAssociationError(msg) # Check for duplicate assets in each field - left_field_name, right_field_name = \ - self.get_association_field_names(association) + left_field_name, right_field_name = self.get_association_field_names( + association + ) for field_name in (left_field_name, right_field_name): field_assets = getattr(association, field_name) unique_field_asset_names = {a.name for a in field_assets} if len(field_assets) > len(unique_field_asset_names): - raise ModelAssociationException( - "More than one asset share same name in field" - f"{association_type}.{field_name}" + msg = ( + 'More than one asset share same name in field' + f'{association_type}.{field_name}' ) + raise ModelAssociationException(msg) # For each asset in left field, go through each assets in right field # to find all unique connections. Raise error if a connection between # two assets already exist in a previously added association. for left_asset in getattr(association, left_field_name): for right_asset in getattr(association, right_field_name): - if self.association_exists_between_assets( association_type, left_asset, right_asset ): # Assets already have the connection in another # association with same type - raise DuplicateModelAssociationError( - f"Association type {association_type} already exists" - f" between {left_asset.name} and {right_asset.name}" + msg = ( + f'Association type {association_type} already exists' + f' between {left_asset.name} and {right_asset.name}' ) + raise DuplicateModelAssociationError(msg) def add_association(self, association: SchemaGeneratedClass) -> None: """Add an association to the model. @@ -352,7 +353,6 @@ def add_association(self, association: SchemaGeneratedClass) -> None: ModelAssociationException - association is not valid """ - # Check association is valid and not duplicate self._validate_association(association) @@ -372,25 +372,22 @@ def add_association(self, association: SchemaGeneratedClass) -> None: # Add association to type->association mapping association_type = association.type - self._type_to_association.setdefault( - association_type, [] - ).append(association) - + self._type_to_association.setdefault(association_type, []).append(association) def remove_association(self, association: SchemaGeneratedClass) -> None: """Remove an association from the model. Arguments: association - the association to remove from the model - """ + """ if association not in self.associations: - raise LookupError( - f'Association is not part of model "{self.name}".' - ) + msg = f'Association is not part of model "{self.name}".' + raise LookupError(msg) - left_field_name, right_field_name = \ - self.get_association_field_names(association) + left_field_name, right_field_name = self.get_association_field_names( + association + ) left_field = getattr(association, left_field_name) right_field = getattr(association, right_field_name) @@ -412,25 +409,21 @@ def remove_association(self, association: SchemaGeneratedClass) -> None: # Remove association from type->association mapping association_type = association.type - self._type_to_association[association_type].remove( - association - ) + self._type_to_association[association_type].remove(association) # Remove type from type->association mapping if mapping empty if len(self._type_to_association[association_type]) == 0: del self._type_to_association[association_type] def add_attacker( - self, - attacker: AttackerAttachment, - attacker_id: Optional[int] = None - ) -> None: + self, attacker: AttackerAttachment, attacker_id: int | None = None + ) -> None: """Add an attacker to the model. Arguments: attacker - the attacker to add attacker_id - optional id for the attacker - """ + """ if attacker_id is not None: attacker.id = attacker_id else: @@ -441,136 +434,113 @@ def add_attacker( attacker.name = 'Attacker:' + str(attacker.id) self.attackers.append(attacker) - def get_asset_by_id( - self, asset_id: int - ) -> Optional[SchemaGeneratedClass]: - """ - Find an asset in the model based on its id. + def get_asset_by_id(self, asset_id: int) -> SchemaGeneratedClass | None: + """Find an asset in the model based on its id. Arguments: asset_id - the id of the asset we are looking for Return: An asset matching the id if it exists in the model. - """ - logger.debug( - 'Get asset with id %d from model "%s".', - asset_id, self.name - ) - return next( - (asset for asset in self.assets - if asset.id == asset_id), None - ) - def get_asset_by_name( - self, asset_name: str - ) -> Optional[SchemaGeneratedClass]: """ - Find an asset in the model based on its name. + logger.debug('Get asset with id %d from model "%s".', asset_id, self.name) + return next((asset for asset in self.assets if asset.id == asset_id), None) + + def get_asset_by_name(self, asset_name: str) -> SchemaGeneratedClass | None: + """Find an asset in the model based on its name. Arguments: asset_name - the name of the asset we are looking for Return: An asset matching the name if it exists in the model. - """ - logger.debug( - 'Get asset with name "%s" from model "%s".', - asset_name, self.name - ) - return next( - (asset for asset in self.assets - if asset.name == asset_name), None - ) - def get_attacker_by_id( - self, attacker_id: int - ) -> Optional[AttackerAttachment]: """ - Find an attacker in the model based on its id. + logger.debug('Get asset with name "%s" from model "%s".', asset_name, self.name) + return next((asset for asset in self.assets if asset.name == asset_name), None) + + def get_attacker_by_id(self, attacker_id: int) -> AttackerAttachment | None: + """Find an attacker in the model based on its id. Arguments: attacker_id - the id of the attacker we are looking for Return: An attacker matching the id if it exists in the model. + """ - logger.debug( - 'Get attacker with id %d from model "%s".', - attacker_id, self.name - ) + logger.debug('Get attacker with id %d from model "%s".', attacker_id, self.name) return next( - (attacker for attacker in self.attackers - if attacker.id == attacker_id), None - ) + (attacker for attacker in self.attackers if attacker.id == attacker_id), + None, + ) def association_exists_between_assets( - self, - association_type: str, - left_asset: SchemaGeneratedClass, - right_asset: SchemaGeneratedClass - ): - """Return True if the association already exists between the assets""" + self, + association_type: str, + left_asset: SchemaGeneratedClass, + right_asset: SchemaGeneratedClass, + ) -> bool: + """Return True if the association already exists between the assets.""" logger.debug( 'Check to see if an association of type "%s" ' 'already exists between "%s" and "%s".', - association_type, left_asset.name, right_asset.name + association_type, + left_asset.name, + right_asset.name, ) associations = self._type_to_association.get(association_type, []) for association in associations: - left_field_name, right_field_name = \ - self.get_association_field_names(association) - if (left_asset.id in [asset.id for asset in \ - getattr(association, left_field_name)] and \ - right_asset.id in [asset.id for asset in \ - getattr(association, right_field_name)]): - logger.debug( - 'An association of type "%s" ' - 'already exists between "%s" and "%s".', - association_type, left_asset.name, right_asset.name - ) - return True + left_field_name, right_field_name = self.get_association_field_names( + association + ) + if left_asset.id in [ + asset.id for asset in getattr(association, left_field_name) + ] and right_asset.id in [ + asset.id for asset in getattr(association, right_field_name) + ]: + logger.debug( + 'An association of type "%s" already exists between "%s" and "%s".', + association_type, + left_asset.name, + right_asset.name, + ) + return True logger.debug( - 'No association of type "%s" ' - 'exists between "%s" and "%s".', - association_type, left_asset.name, right_asset.name + 'No association of type "%s" exists between "%s" and "%s".', + association_type, + left_asset.name, + right_asset.name, ) return False def get_asset_defenses( - self, - asset: SchemaGeneratedClass, - include_defaults: bool = False - ): - """ - Get the two field names of the association as a list. + self, asset: SchemaGeneratedClass, include_defaults: bool = False + ): + """Get the two field names of the association as a list. + Arguments: asset - the asset to fetch the defenses for include_defaults - if not True the defenses that have default - values will not be included in the list + values will not be included in the list. Return: A dictionary containing the defenses of the asset - """ + """ defenses = {} for key, value in asset._properties.items(): - property_schema = ( - self.lang_classes_factory.json_schema['definitions'] - ['LanguageAsset'] ['definitions'] - ['Asset_' + asset.type]['properties'][key] - ) + property_schema = self.lang_classes_factory.json_schema['definitions'][ + 'LanguageAsset' + ]['definitions']['Asset_' + asset.type]['properties'][key] - if "maximum" not in property_schema: + if 'maximum' not in property_schema: # Check if property is a defense by looking up defense # specific key. Skip if it is not a defense. continue - logger.debug( - 'Translating %s: %s defense to dictionary.', - key, - value - ) + logger.debug('Translating %s: %s defense to dictionary.', key, value) if not include_defaults and value == value.default(): # Skip the defense values if they are the default ones. @@ -580,29 +550,22 @@ def get_asset_defenses( return defenses - def get_association_field_names( - self, - association: SchemaGeneratedClass - ): - """ - Get the two field names of the association as a list. + def get_association_field_names(self, association: SchemaGeneratedClass): + """Get the two field names of the association as a list. + Arguments: - association - the association to fetch the field names for + association - the association to fetch the field names for. Return: A two item list containing the field names of the association. - """ + """ return list(association._properties.keys())[1:] - def get_associated_assets_by_field_name( - self, - asset: SchemaGeneratedClass, - field_name: str - ) -> list[SchemaGeneratedClass]: - """ - Get a list of associated assets for an asset given a field name. + self, asset: SchemaGeneratedClass, field_name: str + ) -> list[SchemaGeneratedClass]: + """Get a list of associated assets for an asset given a field name. Arguments: asset - the asset whose fields we are interested in @@ -611,10 +574,13 @@ def get_associated_assets_by_field_name( Return: A list of assets associated with the asset given that match the field_name. + """ logger.debug( 'Get associated assets for asset "%s"(%d) by field name %s.', - asset.name, asset.id, field_name + asset.name, + asset.id, + field_name, ) associated_assets = [] for association in asset.associations: @@ -630,19 +596,11 @@ def asset_to_dict(self, asset: SchemaGeneratedClass) -> tuple[str, dict]: asset - asset to get dictionary representation of Return: tuple with name of asset and the asset as dict - """ - - logger.debug( - 'Translating "%s"(%d) to dictionary.', - asset.name, - asset.id - ) + """ + logger.debug('Translating "%s"(%d) to dictionary.', asset.name, asset.id) - asset_dict: dict[str, Any] = { - 'name': str(asset.name), - 'type': str(asset.type) - } + asset_dict: dict[str, Any] = {'name': str(asset.name), 'type': str(asset.type)} defenses = self.get_asset_defenses(asset) @@ -655,7 +613,6 @@ def asset_to_dict(self, asset: SchemaGeneratedClass) -> tuple[str, dict]: return (asset.id, asset_dict) - def association_to_dict(self, association: SchemaGeneratedClass) -> dict: """Get dictionary representation of the association. @@ -663,20 +620,22 @@ def association_to_dict(self, association: SchemaGeneratedClass) -> dict: association - association to get dictionary representation of Returns the association serialized to a dict - """ - left_field_name, right_field_name = \ - self.get_association_field_names(association) + """ + left_field_name, right_field_name = self.get_association_field_names( + association + ) left_field = getattr(association, left_field_name) right_field = getattr(association, right_field_name) association_dict = { - str(association.type) : - { - str(left_field_name): - {int(asset.id): str(asset.name) for asset in left_field}, - str(right_field_name): - {int(asset.id): str(asset.name) for asset in right_field} + str(association.type): { + str(left_field_name): { + int(asset.id): str(asset.name) for asset in left_field + }, + str(right_field_name): { + int(asset.id): str(asset.name) for asset in right_field + }, } } @@ -686,24 +645,22 @@ def association_to_dict(self, association: SchemaGeneratedClass) -> dict: return association_dict - def attacker_to_dict( - self, attacker: AttackerAttachment - ) -> tuple[Optional[int], dict]: + def attacker_to_dict(self, attacker: AttackerAttachment) -> tuple[int | None, dict]: """Get dictionary representation of the attacker. Arguments: attacker - attacker to get dictionary representation of - """ + """ logger.debug('Translating %s to dictionary.', attacker.name) attacker_dict: dict[str, Any] = { 'name': str(attacker.name), 'entry_points': {}, } - for (asset, attack_steps) in attacker.entry_points: + for asset, attack_steps in attacker.entry_points: attacker_dict['entry_points'][str(asset.name)] = { 'asset_id': int(asset.id), - 'attack_steps' : attack_steps + 'attack_steps': attack_steps, } return (attacker.id, attacker_dict) @@ -714,7 +671,7 @@ def _to_dict(self) -> dict: 'metadata': {}, 'assets': {}, 'associations': [], - 'attackers' : {} + 'attackers': {}, } contents['metadata'] = { 'name': self.name, @@ -722,7 +679,7 @@ def _to_dict(self) -> dict: 'langID': self.lang_classes_factory.lang_graph.metadata['id'], 'malVersion': '0.1.0-SNAPSHOT', 'MAL-Toolbox Version': __version__, - 'info': 'Created by the mal-toolbox model python module.' + 'info': 'Created by the mal-toolbox model python module.', } logger.debug('Translating assets to dictionary.') @@ -742,68 +699,64 @@ def _to_dict(self) -> dict: return contents def save_to_file(self, filename: str) -> None: - """Save to json/yml depending on extension""" + """Save to json/yml depending on extension.""" logger.debug('Save instance model to file "%s".', filename) return save_dict_to_file(filename, self._to_dict()) @classmethod def _from_dict( - cls, - serialized_object: dict, - lang_classes_factory: LanguageClassesFactory - ) -> Model: - """Create a model from dict representation + cls, serialized_object: dict, lang_classes_factory: LanguageClassesFactory + ) -> Model: + """Create a model from dict representation. Arguments: serialized_object - Model in dict format lang_classes_factory - - """ - maltoolbox_version = serialized_object['metadata']['MAL Toolbox Version'] \ - if 'MAL Toolbox Version' in serialized_object['metadata'] \ - else __version__ + """ + maltoolbox_version = serialized_object['metadata'].get( + 'MAL Toolbox Version', __version__ + ) model = Model( serialized_object['metadata']['name'], lang_classes_factory, - mt_version = maltoolbox_version) + mt_version=maltoolbox_version, + ) # Reconstruct the assets for asset_id, asset_object in serialized_object['assets'].items(): - if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug - logger.debug( - "Loading asset:\n%s", json.dumps(asset_object, indent=2) - ) + logger.debug('Loading asset:\n%s', json.dumps(asset_object, indent=2)) # Allow defining an asset via type only. asset_object = ( asset_object if isinstance(asset_object, dict) - else { - 'type': asset_object, - 'name': f"{asset_object}:{asset_id}" - } + else {'type': asset_object, 'name': f'{asset_object}:{asset_id}'} ) asset_type_class = model.lang_classes_factory.get_asset_class( - asset_object['type']) + asset_object['type'] + ) # TODO: remove this when factory goes away asset_type_class.__hash__ = lambda self: hash(self.name) # type: ignore[method-assign,misc] if asset_type_class is None: - raise LookupError('Failed to find asset "%s" in language' - ' classes factory' % asset_object['type']) - asset = asset_type_class(name = asset_object['name']) + msg = 'Failed to find asset "{}" in language classes factory'.format( + asset_object['type'] + ) + raise LookupError(msg) + asset = asset_type_class(name=asset_object['name']) if 'extras' in asset_object: asset.extras = asset_object['extras'] - for defense in (defenses:=asset_object.get('defenses', [])): + for defense in (defenses := asset_object.get('defenses', [])): setattr(asset, defense, float(defenses[defense])) - model.add_asset(asset, asset_id = int(asset_id)) + model.add_asset(asset, asset_id=int(asset_id)) # Reconstruct the associations for assoc_entry in serialized_object.get('associations', []): @@ -811,23 +764,27 @@ def _from_dict( assoc_keys_iter = iter(assoc_fields) field1 = next(assoc_keys_iter) field2 = next(assoc_keys_iter) - assoc_type_class = model.lang_classes_factory.\ - get_association_class_by_fieldnames(assoc, field1, field2) + assoc_type_class = ( + model.lang_classes_factory.get_association_class_by_fieldnames( + assoc, field1, field2 + ) + ) if assoc_type_class is None: - raise LookupError('Failed to find association "%s" with ' - 'fields "%s" and "%s" in language classes factory' % - (assoc, field1, field2) + msg = ( + f'Failed to find association "{assoc}" with ' + f'fields "{field1}" and "{field2}" in language classes factory' ) + raise LookupError(msg) association = assoc_type_class() for field, targets in assoc_fields.items(): setattr( association, field, - [model.get_asset_by_id(int(id)) for id in targets] + [model.get_asset_by_id(int(id)) for id in targets], ) - #TODO Properly handle extras + # TODO Properly handle extras model.add_association(association) @@ -835,27 +792,25 @@ def _from_dict( if 'attackers' in serialized_object: attackers_info = serialized_object['attackers'] for attacker_id in attackers_info: - attacker = AttackerAttachment(name = attackers_info[attacker_id]['name']) + attacker = AttackerAttachment(name=attackers_info[attacker_id]['name']) attacker.entry_points = [] - for asset_name, entry_points_dict in \ - attackers_info[attacker_id]['entry_points'].items(): + for entry_points_dict in attackers_info[attacker_id][ + 'entry_points' + ].values(): attacker.entry_points.append( - ( - model.get_asset_by_id( - entry_points_dict['asset_id']), - entry_points_dict['attack_steps'] - ) + ( + model.get_asset_by_id(entry_points_dict['asset_id']), + entry_points_dict['attack_steps'], ) - model.add_attacker(attacker, attacker_id = int(attacker_id)) + ) + model.add_attacker(attacker, attacker_id=int(attacker_id)) return model @classmethod def load_from_file( - cls, - filename: str, - lang_classes_factory: LanguageClassesFactory - ) -> Model: - """Create from json or yaml file depending on file extension""" + cls, filename: str, lang_classes_factory: LanguageClassesFactory + ) -> Model: + """Create from json or yaml file depending on file extension.""" logger.debug('Load instance model from file "%s".', filename) serialized_model = None if filename.endswith(('.yml', '.yaml')): @@ -863,5 +818,6 @@ def load_from_file( elif filename.endswith('.json'): serialized_model = load_dict_from_json_file(filename) else: - raise ValueError('Unknown file extension, expected json/yml/yaml') + msg = 'Unknown file extension, expected json/yml/yaml' + raise ValueError(msg) return cls._from_dict(serialized_model, lang_classes_factory) diff --git a/maltoolbox/translators/securicad.py b/maltoolbox/translators/securicad.py index 655d6639..889ff87f 100644 --- a/maltoolbox/translators/securicad.py +++ b/maltoolbox/translators/securicad.py @@ -1,26 +1,22 @@ -""" -MAL-Toolbox securiCAD Translator Module -""" +"""MAL-Toolbox securiCAD Translator Module.""" -import zipfile import json import logging import xml.etree.ElementTree as ET +import zipfile -from typing import Optional - -from ..model import AttackerAttachment, Model -from ..language import LanguageGraph, LanguageClassesFactory +from maltoolbox.language import LanguageClassesFactory, LanguageGraph +from maltoolbox.model import AttackerAttachment, Model logger = logging.getLogger(__name__) + def load_model_from_scad_archive( - scad_archive: str, - lang_graph: LanguageGraph, - lang_classes_factory: LanguageClassesFactory - ) -> Optional[Model]: - """ - Reads a '.sCAD' archive generated by securiCAD representing an instance + scad_archive: str, + lang_graph: LanguageGraph, + lang_classes_factory: LanguageClassesFactory, +) -> Model | None: + """Reads a '.sCAD' archive generated by securiCAD representing an instance model and loads the information into a maltoobox.model.Model object. Arguments: @@ -33,44 +29,41 @@ def load_model_from_scad_archive( Return: A maltoobox.model.Model object containing the instance model. + """ with zipfile.ZipFile(scad_archive, 'r') as archive: filelist = archive.namelist() - model_file = next(filter(lambda x: ( x[-4:] == '.eom'), filelist)) + model_file = next(filter(lambda x: (x[-4:] == '.eom'), filelist)) scad_model = archive.read(model_file) root = ET.fromstring(scad_model) - instance_model = Model(scad_archive, - lang_classes_factory) + instance_model = Model(scad_archive, lang_classes_factory) for child in root.iter('objects'): - if logger.isEnabledFor(logging.DEBUG): # Avoid running json.dumps when not in debug logger.debug( 'Loading asset from "%s": \n%s', - scad_archive, json.dumps(child.attrib, indent=2) + scad_archive, + json.dumps(child.attrib, indent=2), ) if child.attrib['metaConcept'] == 'Attacker': attacker_obj_id = int(child.attrib['id']) attacker_at = AttackerAttachment() attacker_at.entry_points = [] - instance_model.add_attacker( - attacker_at, - attacker_id = attacker_obj_id - ) + instance_model.add_attacker(attacker_at, attacker_id=attacker_obj_id) continue - if not hasattr(lang_classes_factory.ns, - child.attrib['metaConcept']): + if not hasattr(lang_classes_factory.ns, child.attrib['metaConcept']): logger.error( 'Failed to find %s asset in language specification!', - child.attrib["metaConcept"] + child.attrib['metaConcept'], ) return None - asset = getattr(lang_classes_factory.ns, - child.attrib['metaConcept'])(name = child.attrib['name']) + asset = getattr(lang_classes_factory.ns, child.attrib['metaConcept'])( + name=child.attrib['name'] + ) asset_id = int(child.attrib['id']) for subchild in child.iter('evidenceAttributes'): defense_name = subchild.attrib['metaConcept'] @@ -85,9 +78,11 @@ def load_model_from_scad_archive( for child in root.iter('associations'): logger.debug( 'Load association ("%s", "%s", "%s", "%s") from %s', - child.attrib["sourceObject"], child.attrib["targetObject"], - child.attrib["targetProperty"], child.attrib["sourceProperty"], - scad_archive + child.attrib['sourceObject'], + child.attrib['targetObject'], + child.attrib['targetProperty'], + child.attrib['sourceProperty'], + scad_archive, ) # Note: This is not a bug in the code. The fields and assets are # listed incorrectly in the securiCAD format where the source asset @@ -108,32 +103,23 @@ def load_model_from_scad_archive( attacker = instance_model.get_attacker_by_id(attacker_id) if not attacker: logger.error( - 'Failed to find attacker with id %s in model!', - attacker_id + 'Failed to find attacker with id %s in model!', attacker_id ) return None target_asset = instance_model.get_asset_by_id(target_id) if not target_asset: - logger.error( - 'Failed to find asset with id %s in model!', - target_id - ) + logger.error('Failed to find asset with id %s in model!', target_id) return None - attacker.entry_points.append((target_asset, - [target_prop.split('.')[0]])) + attacker.entry_points.append((target_asset, [target_prop.split('.')[0]])) continue left_asset = instance_model.get_asset_by_id(left_id) if not left_asset: - logger.error( - 'Failed to find asset with id %s in model!', left_id - ) + logger.error('Failed to find asset with id %s in model!', left_id) return None right_asset = instance_model.get_asset_by_id(right_id) if not right_asset: - logger.error( - 'Failed to find asset with id %s in model!', right_id - ) + logger.error('Failed to find asset with id %s in model!', right_id) return None # Note: This is not a bug in the code. The fields and assets are @@ -143,33 +129,33 @@ def load_model_from_scad_archive( right_field = child.attrib['targetProperty'] lang_graph_assoc = None for assoc in left_asset.lg_asset.associations: - if (assoc.left_field.fieldname == left_field and - assoc.right_field.fieldname == right_field) or \ - (assoc.left_field.fieldname == right_field and - assoc.right_field.fieldname == left_field): + if ( + assoc.left_field.fieldname == left_field + and assoc.right_field.fieldname == right_field + ) or ( + assoc.left_field.fieldname == right_field + and assoc.right_field.fieldname == left_field + ): lang_graph_assoc = assoc break if not lang_graph_assoc: - raise LookupError( - 'Failed to find ("%s", "%s", "%s", "%s")' - 'association in lang specification.' % - (left_asset.type, right_asset.type, - left_field, right_field) + msg = ( + f'Failed to find ("{left_asset.type}", "{right_asset.type}", "{left_field}", "{right_field}")' + 'association in lang specification.' ) + raise LookupError(msg) return None logger.debug('Found "%s" association.', lang_graph_assoc.name) assoc_name = lang_classes_factory.get_association_by_signature( - lang_graph_assoc.name, - left_asset.type, - right_asset.type + lang_graph_assoc.name, left_asset.type, right_asset.type ) if assoc_name is None: logger.error( - 'Failed to find association with name \"%s\" in model!', - lang_graph_assoc.name + 'Failed to find association with name "%s" in model!', + lang_graph_assoc.name, ) return None diff --git a/maltoolbox/translators/updater.py b/maltoolbox/translators/updater.py index f0f92dfa..91ca4d79 100644 --- a/maltoolbox/translators/updater.py +++ b/maltoolbox/translators/updater.py @@ -3,37 +3,34 @@ import yaml -from ..model import Model, AttackerAttachment -from ..language import LanguageClassesFactory +from maltoolbox.language import LanguageClassesFactory +from maltoolbox.model import AttackerAttachment, Model logger = logging.getLogger(__name__) + def load_model_from_older_version( - filename: str, - lang_classes_factory: LanguageClassesFactory, - version: str - ) -> Model: - match (version): + filename: str, lang_classes_factory: LanguageClassesFactory, version: str +) -> Model: + match version: case '0.0.39': - return load_model_from_version_0_0_39(filename, - lang_classes_factory) + return load_model_from_version_0_0_39(filename, lang_classes_factory) case _: - msg = ('Unknown version "%s" format. Could not ' - 'load model from file "%s"') + msg = 'Unknown version "%s" format. Could not load model from file "%s"' logger.error(msg % (version, filename)) raise ValueError(msg % (version, filename)) + def load_model_from_version_0_0_39( - filename: str, - lang_classes_factory: LanguageClassesFactory - ) -> Model: - """ - Load model from file. + filename: str, lang_classes_factory: LanguageClassesFactory +) -> Model: + """Load model from file. Arguments: filename - the name of the input file lang_classes_factory - the language classes factory that defines the classes needed to build the model + """ def _process_model(model_dict, lang_classes_factory) -> Model: @@ -41,26 +38,29 @@ def _process_model(model_dict, lang_classes_factory) -> Model: # Reconstruct the assets for asset_id, asset_object in model_dict['assets'].items(): - logger.debug(f"Loading asset:\n{json.dumps(asset_object, indent=2)}") + logger.debug(f'Loading asset:\n{json.dumps(asset_object, indent=2)}') # Allow defining an asset via the metaconcept only. asset_object = ( asset_object if isinstance(asset_object, dict) - else {'metaconcept': asset_object, 'name': f"{asset_object}:{asset_id}"} + else {'metaconcept': asset_object, 'name': f'{asset_object}:{asset_id}'} ) - asset = getattr(model.lang_classes_factory.ns, - asset_object['metaconcept'])(name = asset_object['name']) + asset = getattr(model.lang_classes_factory.ns, asset_object['metaconcept'])( + name=asset_object['name'] + ) - for defense in (defenses:=asset_object.get('defenses', [])): + for defense in (defenses := asset_object.get('defenses', [])): setattr(asset, defense, float(defenses[defense])) - model.add_asset(asset, asset_id = int(asset_id)) + model.add_asset(asset, asset_id=int(asset_id)) # Reconstruct the associations for assoc_dict in model_dict.get('associations', []): - association = getattr(model.lang_classes_factory.ns, assoc_dict.pop('metaconcept'))() + association = getattr( + model.lang_classes_factory.ns, assoc_dict.pop('metaconcept') + )() # compatibility with old format assoc_dict = assoc_dict.get('association', assoc_dict) @@ -70,7 +70,7 @@ def _process_model(model_dict, lang_classes_factory) -> Model: setattr( association, field, - [model.get_asset_by_id(int(id)) for id in targets] + [model.get_asset_by_id(int(id)) for id in targets], ) model.add_association(association) @@ -78,55 +78,54 @@ def _process_model(model_dict, lang_classes_factory) -> Model: if 'attackers' in model_dict: attackers_info = model_dict['attackers'] for attacker_id in attackers_info: - attacker = AttackerAttachment( - name = attackers_info[attacker_id]['name'] - ) + attacker = AttackerAttachment(name=attackers_info[attacker_id]['name']) attacker.entry_points = [] for asset_id in attackers_info[attacker_id]['entry_points']: attacker.entry_points.append( - (model.get_asset_by_id(int(asset_id)), - attackers_info[attacker_id]['entry_points']\ - [asset_id]['attack_steps'])) - model.add_attacker(attacker, attacker_id = int(attacker_id)) + ( + model.get_asset_by_id(int(asset_id)), + attackers_info[attacker_id]['entry_points'][asset_id][ + 'attack_steps' + ], + ) + ) + model.add_attacker(attacker, attacker_id=int(attacker_id)) return model def load_from_json( - filename: str, - lang_classes_factory: LanguageClassesFactory - ) -> Model: - """ - Load model from a json file. + filename: str, lang_classes_factory: LanguageClassesFactory + ) -> Model: + """Load model from a json file. Arguments: filename - the name of the input file + """ - with open(filename, 'r', encoding='utf-8') as model_file: + with open(filename, encoding='utf-8') as model_file: model_dict = json.loads(model_file.read()) return _process_model(model_dict, lang_classes_factory) def load_from_yaml( - filename: str, - lang_classes_factory: LanguageClassesFactory - ) -> Model: - """ - Load model from a yaml file. + filename: str, lang_classes_factory: LanguageClassesFactory + ) -> Model: + """Load model from a yaml file. Arguments: filename - the name of the input file + """ - with open(filename, 'r', encoding='utf-8') as model_file: + with open(filename, encoding='utf-8') as model_file: model_dict = yaml.safe_load(model_file) return _process_model(model_dict, lang_classes_factory) logger.info(f'Loading model from {filename} file.') - if filename.endswith('.yml') or filename.endswith('.yaml'): + if filename.endswith(('.yml', '.yaml')): return load_from_yaml(filename, lang_classes_factory) - elif filename.endswith('.json'): + if filename.endswith('.json'): return load_from_json(filename, lang_classes_factory) - else: - msg = 'Unknown file extension for model file to load from.' - logger.error(msg) - raise ValueError(msg) + msg = 'Unknown file extension for model file to load from.' + logger.error(msg) + raise ValueError(msg) return None diff --git a/maltoolbox/wrappers.py b/maltoolbox/wrappers.py index c25b47d7..4c7dcda4 100644 --- a/maltoolbox/wrappers.py +++ b/maltoolbox/wrappers.py @@ -1,34 +1,33 @@ -"""Contains wrappers combining more than one of the maltoolbox submodules""" +"""Contains wrappers combining more than one of the maltoolbox submodules.""" import logging import sys import zipfile -from maltoolbox.model import Model -from maltoolbox.language import LanguageGraph, LanguageClassesFactory +from maltoolbox import log_configs from maltoolbox.attackgraph import AttackGraph -from maltoolbox.attackgraph.analyzers.apriori import ( - calculate_viability_and_necessity -) +from maltoolbox.attackgraph.analyzers.apriori import calculate_viability_and_necessity from maltoolbox.exceptions import AttackGraphStepExpressionError -from maltoolbox import log_configs - +from maltoolbox.language import LanguageClassesFactory, LanguageGraph +from maltoolbox.model import Model logger = logging.getLogger(__name__) + def create_attack_graph( - lang_file: str, - model_file: str, - attach_attackers=True, - calc_viability_and_necessity=True - ) -> AttackGraph: - """Create and return an attack graph + lang_file: str, + model_file: str, + attach_attackers=True, + calc_viability_and_necessity=True, +) -> AttackGraph: + """Create and return an attack graph. Args: lang_file - path to language file (.mar or .mal) model_file - path to model file (yaml or json) attach_attackers - whether to run attach_attackers or not calc_viability_and_necessity - whether run apriori calculations or not + """ try: lang_graph = LanguageGraph.from_mar_archive(lang_file) @@ -47,7 +46,7 @@ def create_attack_graph( try: attack_graph = AttackGraph(lang_graph, instance_model) except AttackGraphStepExpressionError: - logger.error( + logger.exception( 'Attack graph generation failed when attempting ' 'to resolve attack step expression!' ) diff --git a/pyproject.toml b/pyproject.toml index 617621f5..2ef4fef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mal-toolbox" -version = "0.2.0" +version = "0.1.12" authors = [ { name="Andrei Buhaiu", email="buhaiu@kth.se" }, { name="Giuseppe Nebbione", email="nebbione@kth.se" }, @@ -56,3 +56,164 @@ exclude = [ 'maltoolbox/language/compiler' ] allow_redefinition = true + +[tool.ruff] +# Same as TWMN version +target-version = "py310" + +[tool.ruff.format] +quote-style = "single" + +[tool.ruff.lint] +# Enable all available lints... +select = ["ALL"] +preview = true + +# ...except for these: +# NOTE: Some rules are disabled due to either conflicting with, or being made redundant by, +# formatter checks (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules) +# The following lines can be read as ` # ()` +ignore = [ + "ANN1", # Require type annotations for `self` (these are unnecessary) + "FIX", # Forbid TODO, FIXME etc. in the code (these should be reviewed individually) + "D203", # Require one blank line before class docstring (doctrings should be directly after class definition) + "D213", # Start multi-line docstring on the line after """ (our convention is to start on the same line) + "TD002", # Require author for TODOs (TODOs don't have to be explicitly assigned) + "TD003", # Require issue link following TODO (TODOs don't have to be related to an issue) + "Q", # Ensure proper quoting style (checked by formatter) + "W191", # Use tabs for indentation (formatter config disallows this) + "E111", # Ensure proper indentation style (checked by formatter) + "E114", # Ensure proper indentation style (checked by formatter) + "E117", # Ensure proper indentation style (checked by formatter) + "D206", # Use spaces for indentation (checked by formatter) + "D300", # Forbid the use of single quotes for docstrings (checked by formatter) + "COM812", # Require trailing commas where applicable (checked by formatter) + "COM819", # Forbid trailing comma on single line constructs (checked by formatter) + "ISC001", # Forbid implicitly concatenated string literals on one line (checked by formatter) + "ISC002", # Forbid implicitly concatenated string literals over multiple lines (checked by formatter) + "E501", # Forbid lines that are too long (checked by formatter) + # Additionally, the formatter may fail to wrap certain lines, thus using the linter for + # this rule instead of the formatter risks failing even when the formatter has been run + "INP001", # implicit-namespace-package, we do want to support namespace packages + "PLR0912", # C901 is a more general check of complexity + "S603", # unreliable check, also we don't need it (we don't handle user input) + "CPY001", # TWMN files don't contain license headers + "S404", # Warn about usage of subprocess for security concerns + # (1. we don't handle user input, 2. must sometimes use either way) + "DOC", # Ruff will soon replace pydoclint in our setup. However, it does not currently + # support the Sphinx-style doctrings we use, so until they do, these rules will + # yield incorrect results. See https://github.com/astral-sh/ruff/issues/12434. + + + # TODO: remove these one by one and fix the issues + "D102", + "ANN201", + "F405", + "N802", + "ANN001", + "D205", + "D417", + "PLR6301", + "D107", + "SLF001", + "N803", + "D106", + "PLC0205", + "PLR2004", + "N806", + "ERA001", + "D103", + "C901", + "N815", + "G004", + "D101", + "D105", + "PLR0915", + "RUF012", + "S101", + "S108", + "PLC0414", + "PTH123", + "PLR0914", + "FBT001", + "D100", + "TRY004", + "FBT002", + "G002", + "D401", + "PLW2901", + "PLR0913", + "PLR0917", + "FBT003", + "A001", + "PTH120", + "D104", + "D419", + "ANN204", + "A002", + "TD004", + "PTH118", + "N801", + "W505", + "PLR0904", + "PLR0911", + "YTT203", + "ANN202", + "PT011", + "PT018", + "SIM102", + "ARG001", + "PTH100", + "E402", + "E741", + "F402", + "F403", + "UP031", + "FURB101", + "ANN002", + "ANN003", + "ANN401", + "S314", + "S405", + "B018", + "EXE001", + "ISC003", + "G003", + "ARG002", + "PTH103", + "PTH119", + "N818", + "E743", + "PLC1901", + "PLR1702", +] + +# Allow unused variables only if they're named `_`. +dummy-variable-rgx = "^_$" + +# Try to fix all that it can +fixable = ["ALL"] + +# Override certain rules based on file path +[tool.ruff.lint.per-file-ignores] +# Test files +"**/tests/*" = [ + "PLR2004", # Allow magic value comparisons + "S101", # Allow `assert` statements + "S311", # Allow use of the `random` module (no need for cryptographic security) +] + +# Don't require superfluous parentheses in pytest fixures. +# See https://github.com/astral-sh/ruff/pull/12106 for discussion. +[tool.ruff.lint.flake8-pytest-style] +fixture-parentheses = false +mark-parentheses = false + +[tool.ruff.lint.isort] +split-on-trailing-comma = true +known-first-party = ["maltoolbox"] + +# Unless this is explicitly set, Ruff won't check for long lines in docstrings. +# Thus, set this to the same ass Ruff's default length (88). +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 diff --git a/tests/attackgraph/test_analyzer.py b/tests/attackgraph/test_analyzer.py index e6124d47..b7d34a9e 100644 --- a/tests/attackgraph/test_analyzer.py +++ b/tests/attackgraph/test_analyzer.py @@ -1,81 +1,79 @@ -"""Tests for analyzers""" +"""Tests for analyzers.""" from maltoolbox.attackgraph import AttackGraphNode -from maltoolbox.attackgraph.analyzers.apriori import propagate_viability_from_unviable_node +from maltoolbox.attackgraph.analyzers.apriori import ( + propagate_viability_from_unviable_node, +) # Apriori analyzer # TODO: Add apriori analyzer test implementations -def test_analyzers_apriori_propagate_viability_from_node(): - """See if viability is propagated correctly""" - pass +def test_analyzers_apriori_propagate_viability_from_node() -> None: + """See if viability is propagated correctly.""" -def test_analyzers_apriori_propagate_necessity_from_node(): - """See if necessity is propagated correctly""" - pass + +def test_analyzers_apriori_propagate_necessity_from_node() -> None: + """See if necessity is propagated correctly.""" -def test_analyzers_apriori_evaluate_viability(): +def test_analyzers_apriori_evaluate_viability() -> None: pass -def test_analyzers_apriori_evaluate_necessity(): +def test_analyzers_apriori_evaluate_necessity() -> None: pass -def test_analyzers_apriori_evaluate_viability_and_necessity(): +def test_analyzers_apriori_evaluate_viability_and_necessity() -> None: pass -def test_analyzers_apriori_calculate_viability_and_necessity(): +def test_analyzers_apriori_calculate_viability_and_necessity() -> None: pass -def test_analyzers_apriori_prune_unviable_and_unnecessary_nodes(): +def test_analyzers_apriori_prune_unviable_and_unnecessary_nodes() -> None: pass -def test_analyzers_apriori_propagate_viability_from_unviable_node(): - r"""Create a graph from nodes - node1 - / \ +def test_analyzers_apriori_propagate_viability_from_unviable_node() -> None: + r"""Create a graph from nodes. + + node1 + / \ node2 node3 - / \ / \ + / \ / \ node4 node5 node6 """ - # Create a graph of nodes according to above diagram node1 = AttackGraphNode( - type = "defense", - name = "node1", - lang_graph_attack_step = None, + type='defense', + name='node1', + lang_graph_attack_step=None, ) node2 = AttackGraphNode( - type = "or", - name = "node2", - lang_graph_attack_step = None, + type='or', + name='node2', + lang_graph_attack_step=None, ) node3 = AttackGraphNode( - type = "or", - name = "node3", - lang_graph_attack_step = None, - defense_status=0.0 + type='or', name='node3', lang_graph_attack_step=None, defense_status=0.0 ) node4 = AttackGraphNode( - type = "or", - name = "node4", - lang_graph_attack_step = None, + type='or', + name='node4', + lang_graph_attack_step=None, ) node5 = AttackGraphNode( - type = "or", - name = "node5", - lang_graph_attack_step = None, + type='or', + name='node5', + lang_graph_attack_step=None, ) node6 = AttackGraphNode( - type = "or", - name = "node6", - lang_graph_attack_step = None, + type='or', + name='node6', + lang_graph_attack_step=None, ) node1.id = 1 @@ -98,7 +96,11 @@ def test_analyzers_apriori_propagate_viability_from_unviable_node(): node1.is_viable = False unviable_nodes = propagate_viability_from_unviable_node(node1) unviable_node_names = {node.name for node in unviable_nodes} - expected_unviable_node_names = set( - [node2.name, node3.name, node4.name, node5.name, node6.name] - ) + expected_unviable_node_names = { + node2.name, + node3.name, + node4.name, + node5.name, + node6.name, + } assert unviable_node_names == expected_unviable_node_names diff --git a/tests/attackgraph/test_attacker.py b/tests/attackgraph/test_attacker.py index 2eaeb6d6..54dc56b6 100644 --- a/tests/attackgraph/test_attacker.py +++ b/tests/attackgraph/test_attacker.py @@ -1,43 +1,43 @@ -"""Unit tests for AttackGraphNode functionality""" +"""Unit tests for AttackGraphNode functionality.""" -from maltoolbox.attackgraph.node import AttackGraphNode from maltoolbox.attackgraph.attacker import Attacker from maltoolbox.attackgraph.attackgraph import AttackGraph +from maltoolbox.attackgraph.node import AttackGraphNode from maltoolbox.language import LanguageGraph -def test_attacker_to_dict(dummy_lang_graph: LanguageGraph): - """Test Attacker to dict conversion""" - dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].\ - attack_steps['DummyAttackStep'] +def test_attacker_to_dict(dummy_lang_graph: LanguageGraph) -> None: + """Test Attacker to dict conversion.""" + dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].attack_steps[ + 'DummyAttackStep' + ] node1 = AttackGraphNode( - type = "or", - name = "node1", - lang_graph_attack_step = dummy_attack_step, + type='or', + name='node1', + lang_graph_attack_step=dummy_attack_step, ) - attacker = Attacker("Test Attacker", [], [node1]) + attacker = Attacker('Test Attacker', [], [node1]) assert attacker.to_dict() == { - "id": None, - "name": "Test Attacker", - "entry_points": {}, - "reached_attack_steps": { - node1.id : str(node1.id) + ':' + node1.name - } + 'id': None, + 'name': 'Test Attacker', + 'entry_points': {}, + 'reached_attack_steps': {node1.id: str(node1.id) + ':' + node1.name}, } -def test_attacker_compromise(dummy_lang_graph: LanguageGraph): - """Attack a node and see expected behavior""" - dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].\ - attack_steps['DummyAttackStep'] +def test_attacker_compromise(dummy_lang_graph: LanguageGraph) -> None: + """Attack a node and see expected behavior.""" + dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].attack_steps[ + 'DummyAttackStep' + ] node1 = AttackGraphNode( - type = "or", - name = "node1", - lang_graph_attack_step = dummy_attack_step, + type='or', + name='node1', + lang_graph_attack_step=dummy_attack_step, ) - attacker = Attacker("Test Attacker", [], []) + attacker = Attacker('Test Attacker', [], []) assert not attacker.entry_points attack_graph = AttackGraph(dummy_lang_graph) attack_graph.add_node(node1) @@ -49,22 +49,23 @@ def test_attacker_compromise(dummy_lang_graph: LanguageGraph): assert node1.compromised_by == [attacker] - attacker.compromise(node1) # Compromise same node again not a problem + attacker.compromise(node1) # Compromise same node again not a problem assert attacker.reached_attack_steps == [node1] assert node1.compromised_by == [attacker] -def test_attacker_undo_compromise(dummy_lang_graph: LanguageGraph): - """Make sure undo compromise removes attacker/node""" - dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].\ - attack_steps['DummyAttackStep'] +def test_attacker_undo_compromise(dummy_lang_graph: LanguageGraph) -> None: + """Make sure undo compromise removes attacker/node.""" + dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].attack_steps[ + 'DummyAttackStep' + ] node1 = AttackGraphNode( - type = "or", - name = "node1", - lang_graph_attack_step = dummy_attack_step, + type='or', + name='node1', + lang_graph_attack_step=dummy_attack_step, ) - attacker = Attacker("attacker1", [], []) + attacker = Attacker('attacker1', [], []) attack_graph = AttackGraph(dummy_lang_graph) attack_graph.add_node(node1) attack_graph.add_attacker(attacker) @@ -72,7 +73,7 @@ def test_attacker_undo_compromise(dummy_lang_graph: LanguageGraph): attacker.compromise(node1) assert attacker.reached_attack_steps == [node1] assert node1.compromised_by == [attacker] - attacker.compromise(node1) # Compromise same node again not a problem + attacker.compromise(node1) # Compromise same node again not a problem assert attacker.reached_attack_steps == [node1] assert node1.compromised_by == [attacker] diff --git a/tests/attackgraph/test_attackgraph.py b/tests/attackgraph/test_attackgraph.py index 993b451f..ec1cc72b 100644 --- a/tests/attackgraph/test_attackgraph.py +++ b/tests/attackgraph/test_attackgraph.py @@ -1,30 +1,29 @@ -"""Unit tests for AttackGraph functionality""" +"""Unit tests for AttackGraph functionality.""" import copy -import pytest from unittest.mock import patch -from maltoolbox.language import LanguageGraph, LanguageClassesFactory -from maltoolbox.language.compiler import MalCompiler -from maltoolbox.attackgraph import AttackGraph, AttackGraphNode, Attacker -from maltoolbox.model import Model, AttackerAttachment - +import pytest from test_model import create_application_asset, create_association +from maltoolbox.attackgraph import Attacker, AttackGraph, AttackGraphNode +from maltoolbox.language import LanguageClassesFactory, LanguageGraph +from maltoolbox.language.compiler import MalCompiler +from maltoolbox.model import AttackerAttachment, Model + @pytest.fixture def example_attackgraph(corelang_lang_graph: LanguageGraph, model: Model): """Fixture that generates an example attack graph - with unattached attacker + with unattached attacker. Uses coreLang specification and model with two applications with an association and an attacker to create and return an AttackGraph object """ - # Create 2 assets - app1 = create_application_asset(model, "Application 1") - app2 = create_application_asset(model, "Application 2") + app1 = create_application_asset(model, 'Application 1') + app2 = create_application_asset(model, 'Application 2') model.add_asset(app1) model.add_asset(app2) @@ -33,46 +32,33 @@ def example_attackgraph(corelang_lang_graph: LanguageGraph, model: Model): model.add_association(assoc) attacker = AttackerAttachment() - attacker.entry_points = [ - (app1, ['networkConnectUninspected']) - ] + attacker.entry_points = [(app1, ['networkConnectUninspected'])] model.add_attacker(attacker) - return AttackGraph( - lang_graph=corelang_lang_graph, - model=model - ) - + return AttackGraph(lang_graph=corelang_lang_graph, model=model) -def test_attackgraph_init(corelang_lang_graph, model): - """Test init with different params given""" +def test_attackgraph_init(corelang_lang_graph, model) -> None: + """Test init with different params given.""" # _generate_graph is called when langspec and model is given to init - with patch("maltoolbox.attackgraph.AttackGraph._generate_graph")\ - as _generate_graph: - AttackGraph( - lang_graph=corelang_lang_graph, - model=model - ) + with patch('maltoolbox.attackgraph.AttackGraph._generate_graph') as _generate_graph: + AttackGraph(lang_graph=corelang_lang_graph, model=model) assert _generate_graph.call_count == 1 # _generate_graph is not called when no model is given - with patch("maltoolbox.attackgraph.AttackGraph._generate_graph")\ - as _generate_graph: - AttackGraph( - lang_graph=corelang_lang_graph, - model=None - ) + with patch('maltoolbox.attackgraph.AttackGraph._generate_graph') as _generate_graph: + AttackGraph(lang_graph=corelang_lang_graph, model=None) assert _generate_graph.call_count == 0 + def attackgraph_save_load_no_model_given( - example_attackgraph: AttackGraph, - corelang_lang_graph: LanguageGraph, - attach_attackers: bool - ): + example_attackgraph: AttackGraph, + corelang_lang_graph: LanguageGraph, + attach_attackers: bool, +) -> None: """Save AttackGraph to a file and load it - Note: Will create file in /tmp""" - + Note: Will create file in /tmp. + """ reward = 1 node_with_reward_before = example_attackgraph.nodes[0] node_with_reward_before.extras['reward'] = reward @@ -81,12 +67,13 @@ def attackgraph_save_load_no_model_given( example_attackgraph.attach_attackers() # Save the example attack graph to /tmp - example_graph_path = "/tmp/example_graph.yml" + example_graph_path = '/tmp/example_graph.yml' example_attackgraph.save_to_file(example_graph_path) # Load the attack graph - loaded_attack_graph = AttackGraph.load_from_file(example_graph_path, - corelang_lang_graph) + loaded_attack_graph = AttackGraph.load_from_file( + example_graph_path, corelang_lang_graph + ) assert node_with_reward_before.id is not None node_with_reward_after = loaded_attack_graph.get_node_by_id( node_with_reward_before.id @@ -103,27 +90,27 @@ def attackgraph_save_load_no_model_given( # Loaded graph nodes will not have 'asset' since it does not have a model. for loaded_node in loaded_attack_graph.nodes: if not isinstance(loaded_node.id, int): - raise ValueError(f'Invalid node id for loaded node.') + msg = 'Invalid node id for loaded node.' + raise ValueError(msg) original_node = example_attackgraph.get_node_by_id(loaded_node.id) - assert original_node, \ - f'Failed to find original node for id {loaded_node.id}.' + assert original_node, f'Failed to find original node for id {loaded_node.id}.' # Convert loaded and original node to dicts loaded_node_dict = loaded_node.to_dict() original_node_dict = original_node.to_dict() for child in original_node_dict['children']: child_node = example_attackgraph.get_node_by_id(child) - assert child_node, \ - f'Failed to find child node for id {child}.' - original_node_dict['children'][child] = str(child_node.id) + \ - ":" + child_node.name + assert child_node, f'Failed to find child node for id {child}.' + original_node_dict['children'][child] = ( + str(child_node.id) + ':' + child_node.name + ) for parent in original_node_dict['parents']: parent_node = example_attackgraph.get_node_by_id(parent) - assert parent_node, \ - f'Failed to find parent node for id {parent}.' - original_node_dict['parents'][parent] = str(parent_node.id) + \ - ":" + parent_node.name + assert parent_node, f'Failed to find parent node for id {parent}.' + original_node_dict['parents'][parent] = ( + str(parent_node.id) + ':' + parent_node.name + ) # Remove key that is not expected to match. del original_node_dict['asset'] @@ -133,68 +120,65 @@ def attackgraph_save_load_no_model_given( for loaded_attacker in loaded_attack_graph.attackers: if not isinstance(loaded_attacker.id, int): - raise ValueError(f'Invalid attacker id for loaded attacker.') - original_attacker = example_attackgraph.get_attacker_by_id( - loaded_attacker.id) - assert original_attacker, \ + msg = 'Invalid attacker id for loaded attacker.' + raise ValueError(msg) + original_attacker = example_attackgraph.get_attacker_by_id(loaded_attacker.id) + assert original_attacker, ( f'Failed to find original attacker for id {loaded_attacker.id}.' + ) loaded_attacker_dict = loaded_attacker.to_dict() original_attacker_dict = original_attacker.to_dict() for step in original_attacker_dict['entry_points']: attack_step_name = original_attacker_dict['entry_points'][step] - attack_step_name = str(step) + ':' + \ - attack_step_name.split(':')[-1] + attack_step_name = str(step) + ':' + attack_step_name.split(':')[-1] original_attacker_dict['entry_points'][step] = attack_step_name for step in original_attacker_dict['reached_attack_steps']: - attack_step_name = \ - original_attacker_dict['reached_attack_steps'][step] - attack_step_name = str(step) + ':' + \ - attack_step_name.split(':')[-1] - original_attacker_dict['reached_attack_steps'][step] = \ - attack_step_name + attack_step_name = original_attacker_dict['reached_attack_steps'][step] + attack_step_name = str(step) + ':' + attack_step_name.split(':')[-1] + original_attacker_dict['reached_attack_steps'][step] = attack_step_name assert loaded_attacker_dict == original_attacker_dict + def test_attackgraph_save_load_no_model_given_without_attackers( - example_attackgraph: AttackGraph, - corelang_lang_graph: LanguageGraph - ): - attackgraph_save_load_no_model_given(example_attackgraph, - corelang_lang_graph, False) + example_attackgraph: AttackGraph, corelang_lang_graph: LanguageGraph +) -> None: + attackgraph_save_load_no_model_given( + example_attackgraph, corelang_lang_graph, False + ) + def test_attackgraph_save_load_no_model_given_with_attackers( - example_attackgraph: AttackGraph, - corelang_lang_graph: LanguageGraph - ): - attackgraph_save_load_no_model_given(example_attackgraph, - corelang_lang_graph, True) + example_attackgraph: AttackGraph, corelang_lang_graph: LanguageGraph +) -> None: + attackgraph_save_load_no_model_given(example_attackgraph, corelang_lang_graph, True) + def attackgraph_save_and_load_json_yml_model_given( - example_attackgraph: AttackGraph, - corelang_lang_graph: LanguageGraph, - attach_attackers: bool - ): + example_attackgraph: AttackGraph, + corelang_lang_graph: LanguageGraph, + attach_attackers: bool, +) -> None: """Try to save and load attack graph from json and yml with model given, - and make sure the dict represenation is the same (except for reward field) + and make sure the dict represenation is the same (except for reward field). """ - if attach_attackers: example_attackgraph.attach_attackers() - for attackgraph_path in ("/tmp/attackgraph.yml", "/tmp/attackgraph.json"): + for attackgraph_path in ('/tmp/attackgraph.yml', '/tmp/attackgraph.json'): example_attackgraph.save_to_file(attackgraph_path) loaded_attackgraph = AttackGraph.load_from_file( - attackgraph_path, - corelang_lang_graph, - model=example_attackgraph.model + attackgraph_path, corelang_lang_graph, model=example_attackgraph.model ) # Make sure model was 'attached' correctly assert loaded_attackgraph.model == example_attackgraph.model - for node_full_name, loaded_node_dict in \ - loaded_attackgraph._to_dict()['attack_steps'].items(): - original_node_dict = \ - example_attackgraph._to_dict()['attack_steps'][node_full_name] + for node_full_name, loaded_node_dict in loaded_attackgraph._to_dict()[ + 'attack_steps' + ].items(): + original_node_dict = example_attackgraph._to_dict()['attack_steps'][ + node_full_name + ] # Make sure nodes are the same (except for the excluded keys) assert loaded_node_dict == original_node_dict @@ -202,7 +186,7 @@ def attackgraph_save_and_load_json_yml_model_given( for node in loaded_attackgraph.nodes: # Make sure node gets an asset when loaded with model assert node.asset - assert node.full_name == node.asset.name + ":" + node.name + assert node.full_name == node.asset.name + ':' + node.name # Make sure node was added to lookup dict with correct id / name assert node.id is not None @@ -211,49 +195,48 @@ def attackgraph_save_and_load_json_yml_model_given( for loaded_attacker in loaded_attackgraph.attackers: if not isinstance(loaded_attacker.id, int): - raise ValueError(f'Invalid attacker id for loaded attacker.') + msg = 'Invalid attacker id for loaded attacker.' + raise ValueError(msg) original_attacker = example_attackgraph.get_attacker_by_id( - loaded_attacker.id) - assert original_attacker, \ - f'Failed to find original attacker for id ' \ - '{loaded_attacker.id}.' + loaded_attacker.id + ) + assert original_attacker, ( + 'Failed to find original attacker for id {loaded_attacker.id}.' + ) loaded_attacker_dict = loaded_attacker.to_dict() original_attacker_dict = original_attacker.to_dict() assert loaded_attacker_dict == original_attacker_dict + def test_attackgraph_save_and_load_json_yml_model_given_without_attackers( - example_attackgraph: AttackGraph, - corelang_lang_graph: LanguageGraph - ): - attackgraph_save_and_load_json_yml_model_given( - example_attackgraph, - corelang_lang_graph, - False - ) + example_attackgraph: AttackGraph, corelang_lang_graph: LanguageGraph +) -> None: + attackgraph_save_and_load_json_yml_model_given( + example_attackgraph, corelang_lang_graph, False + ) + def test_attackgraph_save_and_load_json_yml_model_given_with_attackers( - example_attackgraph: AttackGraph, - corelang_lang_graph: LanguageGraph - ): - attackgraph_save_and_load_json_yml_model_given( - example_attackgraph, - corelang_lang_graph, - True - ) + example_attackgraph: AttackGraph, corelang_lang_graph: LanguageGraph +) -> None: + attackgraph_save_and_load_json_yml_model_given( + example_attackgraph, corelang_lang_graph, True + ) -def test_attackgraph_get_node_by_id(example_attackgraph: AttackGraph): - """Make sure get_node_by_id works as intended""" + +def test_attackgraph_get_node_by_id(example_attackgraph: AttackGraph) -> None: + """Make sure get_node_by_id works as intended.""" assert len(example_attackgraph.nodes) # make sure loop is run for node in example_attackgraph.nodes: if not isinstance(node.id, int): - raise ValueError(f'Invalid node id.') + msg = 'Invalid node id.' + raise ValueError(msg) get_node = example_attackgraph.get_node_by_id(node.id) assert get_node == node -def test_attackgraph_attach_attackers(example_attackgraph: AttackGraph): - """Make sure attackers are properly attached to graph""" - +def test_attackgraph_attach_attackers(example_attackgraph: AttackGraph) -> None: + """Make sure attackers are properly attached to graph.""" app1_ncu = example_attackgraph.get_node_by_full_name( 'Application 1:networkConnectUninspected' ) @@ -273,20 +256,20 @@ def test_attackgraph_attach_attackers(example_attackgraph: AttackGraph): assert app1_ncu in attacker.entry_points assert app1_ncu in attacker.reached_attack_steps - assert not app1_auv in attacker.entry_points - assert not app1_auv in attacker.reached_attack_steps + assert app1_auv not in attacker.entry_points + assert app1_auv not in attacker.reached_attack_steps attacker.compromise(app1_auv) assert app1_auv in attacker.reached_attack_steps assert app1_auv not in attacker.entry_points - for node in attacker.reached_attack_steps: # Make sure the Attacker is present on the nodes they have compromised assert attacker in node.compromised_by -def test_attackgraph_generate_graph(example_attackgraph: AttackGraph): - """Make sure the graph is correctly generated from model and lang""" + +def test_attackgraph_generate_graph(example_attackgraph: AttackGraph) -> None: + """Make sure the graph is correctly generated from model and lang.""" # TODO: Add test cases with defense steps # Empty the attack graph @@ -300,23 +283,22 @@ def test_attackgraph_generate_graph(example_attackgraph: AttackGraph): num_assets_attack_steps = 0 assert example_attackgraph.model for asset in example_attackgraph.model.assets: - attack_steps = example_attackgraph.\ - lang_graph._get_attacks_for_asset_type( - asset.type - ) + attack_steps = example_attackgraph.lang_graph._get_attacks_for_asset_type( + asset.type + ) num_assets_attack_steps += len(attack_steps) # Each attack step will get one node assert len(example_attackgraph.nodes) == num_assets_attack_steps -def test_attackgraph_according_to_corelang(corelang_lang_graph, model): +def test_attackgraph_according_to_corelang(corelang_lang_graph, model) -> None: """Looking at corelang .mal file, make sure the resulting - AttackGraph contains expected nodes""" - + AttackGraph contains expected nodes. + """ # Create 2 assets - app1 = create_application_asset(model, "Application 1") - app2 = create_application_asset(model, "Application 2") + app1 = create_application_asset(model, 'Application 1') + app2 = create_application_asset(model, 'Application 2') model.add_asset(app1) model.add_asset(app2) @@ -327,48 +309,77 @@ def test_attackgraph_according_to_corelang(corelang_lang_graph, model): # These are all attack 71 steps and defenses for Application asset in MAL expected_node_names_application = [ - "notPresent", "attemptUseVulnerability", "successfulUseVulnerability", - "useVulnerability", "attemptReverseReach", "successfulReverseReach", - "reverseReach", "localConnect", "networkConnectUninspected", - "networkConnectInspected", "networkConnect", - "specificAccessNetworkConnect", - "accessNetworkAndConnections", "attemptNetworkConnectFromResponse", - "networkConnectFromResponse", "specificAccessFromLocalConnection", - "specificAccessFromNetworkConnection", "specificAccess", - "bypassContainerization", "authenticate", - "specificAccessAuthenticate", "localAccess", "networkAccess", - "fullAccess", "physicalAccessAchieved", "attemptUnsafeUserActivity", - "successfulUnsafeUserActivity", "unsafeUserActivity", - "attackerUnsafeUserActivityCapability", - "attackerUnsafeUserActivityCapabilityWithReverseReach", - "attackerUnsafeUserActivityCapabilityWithoutReverseReach", - "supplyChainAuditing", "bypassSupplyChainAuditing", - "supplyChainAuditingBypassed", - "attemptFullAccessFromSupplyChainCompromise", - "fullAccessFromSupplyChainCompromise", - "attemptReadFromSoftProdVulnerability", - "attemptModifyFromSoftProdVulnerability", - "attemptDenyFromSoftProdVulnerability", "softwareCheck", - "softwareProductVulnerabilityLocalAccessAchieved", - "softwareProductVulnerabilityNetworkAccessAchieved", - "softwareProductVulnerabilityPhysicalAccessAchieved", - "softwareProductVulnerabilityLowPrivilegesAchieved", - "softwareProductVulnerabilityHighPrivilegesAchieved", - "softwareProductVulnerabilityUserInteractionAchieved", - "attemptSoftwareProductAbuse", - "softwareProductAbuse", "readFromSoftProdVulnerability", - "modifyFromSoftProdVulnerability", - "denyFromSoftProdVulnerability", - "attemptApplicationRespondConnectThroughData", - "successfulApplicationRespondConnectThroughData", - "applicationRespondConnectThroughData", - "attemptAuthorizedApplicationRespondConnectThroughData", - "successfulAuthorizedApplicationRespondConnectThroughData", - "authorizedApplicationRespondConnectThroughData", - "attemptRead", "successfulRead", "read", "specificAccessRead", - "attemptModify", "successfulModify", "modify", "specificAccessModify", - "attemptDeny", "successfulDeny", "deny", - "specificAccessDelete", "denyFromNetworkingAsset", "denyFromLockout" + 'notPresent', + 'attemptUseVulnerability', + 'successfulUseVulnerability', + 'useVulnerability', + 'attemptReverseReach', + 'successfulReverseReach', + 'reverseReach', + 'localConnect', + 'networkConnectUninspected', + 'networkConnectInspected', + 'networkConnect', + 'specificAccessNetworkConnect', + 'accessNetworkAndConnections', + 'attemptNetworkConnectFromResponse', + 'networkConnectFromResponse', + 'specificAccessFromLocalConnection', + 'specificAccessFromNetworkConnection', + 'specificAccess', + 'bypassContainerization', + 'authenticate', + 'specificAccessAuthenticate', + 'localAccess', + 'networkAccess', + 'fullAccess', + 'physicalAccessAchieved', + 'attemptUnsafeUserActivity', + 'successfulUnsafeUserActivity', + 'unsafeUserActivity', + 'attackerUnsafeUserActivityCapability', + 'attackerUnsafeUserActivityCapabilityWithReverseReach', + 'attackerUnsafeUserActivityCapabilityWithoutReverseReach', + 'supplyChainAuditing', + 'bypassSupplyChainAuditing', + 'supplyChainAuditingBypassed', + 'attemptFullAccessFromSupplyChainCompromise', + 'fullAccessFromSupplyChainCompromise', + 'attemptReadFromSoftProdVulnerability', + 'attemptModifyFromSoftProdVulnerability', + 'attemptDenyFromSoftProdVulnerability', + 'softwareCheck', + 'softwareProductVulnerabilityLocalAccessAchieved', + 'softwareProductVulnerabilityNetworkAccessAchieved', + 'softwareProductVulnerabilityPhysicalAccessAchieved', + 'softwareProductVulnerabilityLowPrivilegesAchieved', + 'softwareProductVulnerabilityHighPrivilegesAchieved', + 'softwareProductVulnerabilityUserInteractionAchieved', + 'attemptSoftwareProductAbuse', + 'softwareProductAbuse', + 'readFromSoftProdVulnerability', + 'modifyFromSoftProdVulnerability', + 'denyFromSoftProdVulnerability', + 'attemptApplicationRespondConnectThroughData', + 'successfulApplicationRespondConnectThroughData', + 'applicationRespondConnectThroughData', + 'attemptAuthorizedApplicationRespondConnectThroughData', + 'successfulAuthorizedApplicationRespondConnectThroughData', + 'authorizedApplicationRespondConnectThroughData', + 'attemptRead', + 'successfulRead', + 'read', + 'specificAccessRead', + 'attemptModify', + 'successfulModify', + 'modify', + 'specificAccessModify', + 'attemptDeny', + 'successfulDeny', + 'deny', + 'specificAccessDelete', + 'denyFromNetworkingAsset', + 'denyFromLockout', ] # Make sure the nodes in the AttackGraph have the expected names and order @@ -377,37 +388,35 @@ def test_attackgraph_according_to_corelang(corelang_lang_graph, model): # notPresent is a defense step and its children are (according to corelang): extected_children_of_not_present = [ - "successfulUseVulnerability", - "successfulReverseReach", - "networkConnectFromResponse", - "specificAccessFromLocalConnection", - "specificAccessFromNetworkConnection", - "localAccess", - "networkAccess", - "successfulUnsafeUserActivity", - "fullAccessFromSupplyChainCompromise", - "readFromSoftProdVulnerability", - "modifyFromSoftProdVulnerability", - "denyFromSoftProdVulnerability", - "successfulApplicationRespondConnectThroughData", - "successfulAuthorizedApplicationRespondConnectThroughData", - "successfulRead", - "successfulModify", - "successfulDeny" + 'successfulUseVulnerability', + 'successfulReverseReach', + 'networkConnectFromResponse', + 'specificAccessFromLocalConnection', + 'specificAccessFromNetworkConnection', + 'localAccess', + 'networkAccess', + 'successfulUnsafeUserActivity', + 'fullAccessFromSupplyChainCompromise', + 'readFromSoftProdVulnerability', + 'modifyFromSoftProdVulnerability', + 'denyFromSoftProdVulnerability', + 'successfulApplicationRespondConnectThroughData', + 'successfulAuthorizedApplicationRespondConnectThroughData', + 'successfulRead', + 'successfulModify', + 'successfulDeny', ] # Make sure children are also added for defense step notPresent - not_present_children = [ - n.name for n in attack_graph.nodes[0].children - ] + not_present_children = [n.name for n in attack_graph.nodes[0].children] assert not_present_children == extected_children_of_not_present -def test_attackgraph_regenerate_graph(): - """Make sure graph is regenerated""" - pass + +def test_attackgraph_regenerate_graph() -> None: + """Make sure graph is regenerated.""" -def test_attackgraph_remove_node(example_attackgraph: AttackGraph): - """Make sure nodes are removed correctly""" +def test_attackgraph_remove_node(example_attackgraph: AttackGraph) -> None: + """Make sure nodes are removed correctly.""" node_to_remove = example_attackgraph.nodes[10] parents = list(node_to_remove.parents) children = list(node_to_remove.children) @@ -423,9 +432,8 @@ def test_attackgraph_remove_node(example_attackgraph: AttackGraph): assert node_to_remove not in child.parents -def test_attackgraph_deepcopy(example_attackgraph: AttackGraph): - """ - Try to deepcopy an attackgraph object. The nodes of the attack graph +def test_attackgraph_deepcopy(example_attackgraph: AttackGraph) -> None: + """Try to deepcopy an attackgraph object. The nodes of the attack graph and attackers should be duplicated into new objects, while references to the instance model should remain the same. """ @@ -440,19 +448,21 @@ def test_attackgraph_deepcopy(example_attackgraph: AttackGraph): assert len(copied_attackgraph.nodes) == len(example_attackgraph.nodes) - assert list(copied_attackgraph._id_to_node.keys()) \ - == list(example_attackgraph._id_to_node.keys()) + assert list(copied_attackgraph._id_to_node.keys()) == list( + example_attackgraph._id_to_node.keys() + ) - assert list(copied_attackgraph._id_to_attacker.keys()) \ - == list(example_attackgraph._id_to_attacker.keys()) + assert list(copied_attackgraph._id_to_attacker.keys()) == list( + example_attackgraph._id_to_attacker.keys() + ) - assert list(copied_attackgraph._full_name_to_node.keys()) \ - == list(example_attackgraph._full_name_to_node.keys()) + assert list(copied_attackgraph._full_name_to_node.keys()) == list( + example_attackgraph._full_name_to_node.keys() + ) assert id(copied_attackgraph.model) == id(example_attackgraph.model) - assert len(copied_attackgraph.nodes) \ - == len(example_attackgraph.nodes) + assert len(copied_attackgraph.nodes) == len(example_attackgraph.nodes) for node in copied_attackgraph.nodes: assert node.id is not None @@ -482,16 +492,15 @@ def test_attackgraph_deepcopy(example_attackgraph: AttackGraph): attack_graph_child = copied_attackgraph.get_node_by_id(child.id) assert id(attack_graph_child) == id(child) - assert len(copied_attackgraph.attackers) \ - == len(example_attackgraph.attackers) - assert id(copied_attackgraph.attackers) \ - != id(example_attackgraph.attackers) + assert len(copied_attackgraph.attackers) == len(example_attackgraph.attackers) + assert id(copied_attackgraph.attackers) != id(example_attackgraph.attackers) for attacker in copied_attackgraph.attackers: - for entry_point in attacker.entry_points: assert entry_point.id - entry_point_in_attack_graph = copied_attackgraph.get_node_by_id(entry_point.id) + entry_point_in_attack_graph = copied_attackgraph.get_node_by_id( + entry_point.id + ) assert entry_point_in_attack_graph assert entry_point == entry_point_in_attack_graph assert id(entry_point) == id(entry_point_in_attack_graph) @@ -503,10 +512,8 @@ def test_attackgraph_deepcopy(example_attackgraph: AttackGraph): assert original_attacker.to_dict() == attacker.to_dict() -def test_attackgraph_deepcopy_attackers(example_attackgraph: AttackGraph): - """ - Make sure attackers entry points and reached steps are copied correctly - """ +def test_attackgraph_deepcopy_attackers(example_attackgraph: AttackGraph) -> None: + """Make sure attackers entry points and reached steps are copied correctly.""" example_attackgraph.attach_attackers() original_attacker = example_attackgraph.attackers[0] @@ -537,10 +544,8 @@ def test_attackgraph_deepcopy_attackers(example_attackgraph: AttackGraph): assert id(node) == id(entrypoint) -def test_deepcopy_memo_test(example_attackgraph: AttackGraph): - """ - Make sure memo is filled up with expected number of objects - """ +def test_deepcopy_memo_test(example_attackgraph: AttackGraph) -> None: + """Make sure memo is filled up with expected number of objects.""" example_attackgraph.attach_attackers() memo: dict = {} @@ -556,57 +561,58 @@ def test_deepcopy_memo_test(example_attackgraph: AttackGraph): # Make sure memo stored all of the attackers memo_attackers = [o for o in memo.values() if isinstance(o, Attacker)] - assert len(copied_attackers) == len(memo_attackers) == len(example_attackgraph.attackers) + assert ( + len(copied_attackers) + == len(memo_attackers) + == len(example_attackgraph.attackers) + ) # Make sure memo didn't store any new nodes memo_nodes = [o for o in memo.values() if isinstance(o, AttackGraphNode)] assert len(memo_nodes) == len(example_attackgraph.nodes) -def test_attackgraph_subtype(): - test_lang_graph = LanguageGraph(MalCompiler().compile( - 'tests/testdata/subtype_attack_step.mal')) +def test_attackgraph_subtype() -> None: + test_lang_graph = LanguageGraph( + MalCompiler().compile('tests/testdata/subtype_attack_step.mal') + ) lang_classes_factory = LanguageClassesFactory(test_lang_graph) test_model = Model('Test Model', lang_classes_factory) # Create assets - baseasset1 = lang_classes_factory.get_asset_class('BaseAsset')( - name = 'BaseAsset 1') + baseasset1 = lang_classes_factory.get_asset_class('BaseAsset')(name='BaseAsset 1') - subasset1 = lang_classes_factory.get_asset_class('SubAsset')( - name = 'SubAsset 1') + subasset1 = lang_classes_factory.get_asset_class('SubAsset')(name='SubAsset 1') otherasset1 = lang_classes_factory.get_asset_class('OtherAsset')( - name = 'OtherAsset 1') + name='OtherAsset 1' + ) test_model.add_asset(baseasset1) test_model.add_asset(subasset1) test_model.add_asset(otherasset1) # Create association between subasset1 and otherasset1 - assoc = create_association(test_model, - left_assets = [subasset1, baseasset1], - right_assets = [otherasset1], - assoc_type = 'SubtypeTestAssoc', - left_fieldname = 'field1', - right_fieldname = 'field2') + assoc = create_association( + test_model, + left_assets=[subasset1, baseasset1], + right_assets=[otherasset1], + assoc_type='SubtypeTestAssoc', + left_fieldname='field1', + right_fieldname='field2', + ) test_model.add_association(assoc) - test_attack_graph = AttackGraph( - lang_graph=test_lang_graph, - model=test_model - ) - ba_1_base_step1 = test_attack_graph.get_node_by_full_name( - 'BaseAsset 1:base_step1') - ba_1_base_step2 = test_attack_graph.get_node_by_full_name( - 'BaseAsset 1:base_step2') - sa_1_base_step1 = test_attack_graph.get_node_by_full_name( - 'SubAsset 1:base_step1') - sa_1_base_step2 = test_attack_graph.get_node_by_full_name( - 'SubAsset 1:base_step2') + test_attack_graph = AttackGraph(lang_graph=test_lang_graph, model=test_model) + ba_1_base_step1 = test_attack_graph.get_node_by_full_name('BaseAsset 1:base_step1') + ba_1_base_step2 = test_attack_graph.get_node_by_full_name('BaseAsset 1:base_step2') + sa_1_base_step1 = test_attack_graph.get_node_by_full_name('SubAsset 1:base_step1') + sa_1_base_step2 = test_attack_graph.get_node_by_full_name('SubAsset 1:base_step2') sa_1_subasset_step1 = test_attack_graph.get_node_by_full_name( - 'SubAsset 1:subasset_step1') + 'SubAsset 1:subasset_step1' + ) oa_1_other_step1 = test_attack_graph.get_node_by_full_name( - 'OtherAsset 1:other_step1') + 'OtherAsset 1:other_step1' + ) assert ba_1_base_step1 in oa_1_other_step1.children assert ba_1_base_step2 not in oa_1_other_step1.children @@ -614,23 +620,26 @@ def test_attackgraph_subtype(): assert sa_1_base_step2 in oa_1_other_step1.children assert sa_1_subasset_step1 in oa_1_other_step1.children -def test_attackgraph_setops(): - test_lang_graph = LanguageGraph(MalCompiler().compile( - 'tests/testdata/set_ops.mal')) +def test_attackgraph_setops() -> None: + test_lang_graph = LanguageGraph(MalCompiler().compile('tests/testdata/set_ops.mal')) lang_classes_factory = LanguageClassesFactory(test_lang_graph) test_model = Model('Test Model', lang_classes_factory) # Create assets set_ops_a1 = lang_classes_factory.get_asset_class('SetOpsAssetA')( - name = 'SetOpsAssetA 1') + name='SetOpsAssetA 1' + ) set_ops_b1 = lang_classes_factory.get_asset_class('SetOpsAssetB')( - name = 'SetOpsAssetB 1') + name='SetOpsAssetB 1' + ) set_ops_b2 = lang_classes_factory.get_asset_class('SetOpsAssetB')( - name = 'SetOpsAssetB 2') + name='SetOpsAssetB 2' + ) set_ops_b3 = lang_classes_factory.get_asset_class('SetOpsAssetB')( - name = 'SetOpsAssetB 3') + name='SetOpsAssetB 3' + ) test_model.add_asset(set_ops_a1) test_model.add_asset(set_ops_b1) @@ -638,47 +647,58 @@ def test_attackgraph_setops(): test_model.add_asset(set_ops_b3) # Create association - assoc = create_association(test_model, - left_assets = [set_ops_a1], - right_assets = [set_ops_b1, set_ops_b2], - assoc_type = 'SetOps1', - left_fieldname = 'fieldA1', - right_fieldname = 'fieldB1') + assoc = create_association( + test_model, + left_assets=[set_ops_a1], + right_assets=[set_ops_b1, set_ops_b2], + assoc_type='SetOps1', + left_fieldname='fieldA1', + right_fieldname='fieldB1', + ) test_model.add_association(assoc) - assoc = create_association(test_model, - left_assets = [set_ops_a1], - right_assets = [set_ops_b2, set_ops_b3], - assoc_type = 'SetOps2', - left_fieldname = 'fieldA2', - right_fieldname = 'fieldB2') + assoc = create_association( + test_model, + left_assets=[set_ops_a1], + right_assets=[set_ops_b2, set_ops_b3], + assoc_type='SetOps2', + left_fieldname='fieldA2', + right_fieldname='fieldB2', + ) test_model.add_association(assoc) - test_attack_graph = AttackGraph( - lang_graph=test_lang_graph, - model=test_model - ) + test_attack_graph = AttackGraph(lang_graph=test_lang_graph, model=test_model) assetA1_opsA = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetA 1:testStepSetOpsA') + 'SetOpsAssetA 1:testStepSetOpsA' + ) assetB1_opsB1 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 1:testStepSetOpsB1') + 'SetOpsAssetB 1:testStepSetOpsB1' + ) assetB1_opsB2 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 1:testStepSetOpsB2') + 'SetOpsAssetB 1:testStepSetOpsB2' + ) assetB1_opsB3 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 1:testStepSetOpsB3') + 'SetOpsAssetB 1:testStepSetOpsB3' + ) assetB2_opsB1 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 2:testStepSetOpsB1') + 'SetOpsAssetB 2:testStepSetOpsB1' + ) assetB2_opsB2 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 2:testStepSetOpsB2') + 'SetOpsAssetB 2:testStepSetOpsB2' + ) assetB2_opsB3 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 2:testStepSetOpsB3') + 'SetOpsAssetB 2:testStepSetOpsB3' + ) assetB3_opsB1 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 3:testStepSetOpsB1') + 'SetOpsAssetB 3:testStepSetOpsB1' + ) assetB3_opsB2 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 3:testStepSetOpsB2') + 'SetOpsAssetB 3:testStepSetOpsB2' + ) assetB3_opsB3 = test_attack_graph.get_node_by_full_name( - 'SetOpsAssetB 3:testStepSetOpsB3') + 'SetOpsAssetB 3:testStepSetOpsB3' + ) assert assetB1_opsB1 in assetA1_opsA.children assert assetB1_opsB2 not in assetA1_opsA.children @@ -690,24 +710,20 @@ def test_attackgraph_setops(): assert assetB3_opsB2 not in assetA1_opsA.children assert assetB3_opsB3 not in assetA1_opsA.children -def test_attackgraph_transitive(): - test_lang_graph = LanguageGraph(MalCompiler().compile( - 'tests/testdata/transitive.mal')) + +def test_attackgraph_transitive() -> None: + test_lang_graph = LanguageGraph( + MalCompiler().compile('tests/testdata/transitive.mal') + ) lang_classes_factory = LanguageClassesFactory(test_lang_graph) test_model = Model('Test Model', lang_classes_factory) - asset1 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 1') - asset2 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 2') - asset3 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 3') - asset4 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 4') - asset5 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 5') - asset6 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 6') + asset1 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 1') + asset2 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 2') + asset3 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 3') + asset4 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 4') + asset5 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 5') + asset6 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 6') test_model.add_asset(asset1) test_model.add_asset(asset2) @@ -716,63 +732,64 @@ def test_attackgraph_transitive(): test_model.add_asset(asset5) test_model.add_asset(asset6) - assoc12 = create_association(test_model, - left_assets = [asset1], - right_assets = [asset2], - assoc_type = 'TransitiveTestAssoc', - left_fieldname = 'field1', - right_fieldname = 'field2') + assoc12 = create_association( + test_model, + left_assets=[asset1], + right_assets=[asset2], + assoc_type='TransitiveTestAssoc', + left_fieldname='field1', + right_fieldname='field2', + ) test_model.add_association(assoc12) - assoc23 = create_association(test_model, - left_assets = [asset2], - right_assets = [asset3], - assoc_type = 'TransitiveTestAssoc', - left_fieldname = 'field1', - right_fieldname = 'field2') + assoc23 = create_association( + test_model, + left_assets=[asset2], + right_assets=[asset3], + assoc_type='TransitiveTestAssoc', + left_fieldname='field1', + right_fieldname='field2', + ) test_model.add_association(assoc23) - assoc34 = create_association(test_model, - left_assets = [asset3], - right_assets = [asset4], - assoc_type = 'TransitiveTestAssoc', - left_fieldname = 'field1', - right_fieldname = 'field2') + assoc34 = create_association( + test_model, + left_assets=[asset3], + right_assets=[asset4], + assoc_type='TransitiveTestAssoc', + left_fieldname='field1', + right_fieldname='field2', + ) test_model.add_association(assoc34) - assoc35 = create_association(test_model, - left_assets = [asset3], - right_assets = [asset5], - assoc_type = 'TransitiveTestAssoc', - left_fieldname = 'field1', - right_fieldname = 'field2') + assoc35 = create_association( + test_model, + left_assets=[asset3], + right_assets=[asset5], + assoc_type='TransitiveTestAssoc', + left_fieldname='field1', + right_fieldname='field2', + ) test_model.add_association(assoc35) - assoc61 = create_association(test_model, - left_assets = [asset6], - right_assets = [asset1], - assoc_type = 'TransitiveTestAssoc', - left_fieldname = 'field1', - right_fieldname = 'field2') + assoc61 = create_association( + test_model, + left_assets=[asset6], + right_assets=[asset1], + assoc_type='TransitiveTestAssoc', + left_fieldname='field1', + right_fieldname='field2', + ) test_model.add_association(assoc61) - test_attack_graph = AttackGraph( - lang_graph=test_lang_graph, - model=test_model - ) + test_attack_graph = AttackGraph(lang_graph=test_lang_graph, model=test_model) - asset1_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 1:test_step') - asset2_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 2:test_step') - asset3_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 3:test_step') - asset4_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 4:test_step') - asset5_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 5:test_step') - asset6_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 6:test_step') + asset1_test_step = test_attack_graph.get_node_by_full_name('TestAsset 1:test_step') + asset2_test_step = test_attack_graph.get_node_by_full_name('TestAsset 2:test_step') + asset3_test_step = test_attack_graph.get_node_by_full_name('TestAsset 3:test_step') + asset4_test_step = test_attack_graph.get_node_by_full_name('TestAsset 4:test_step') + asset5_test_step = test_attack_graph.get_node_by_full_name('TestAsset 5:test_step') + asset6_test_step = test_attack_graph.get_node_by_full_name('TestAsset 6:test_step') assert asset1_test_step in asset1_test_step.children assert asset2_test_step in asset1_test_step.children @@ -817,60 +834,54 @@ def test_attackgraph_transitive(): assert asset6_test_step in asset6_test_step.children -def test_attackgraph_transitive_advanced(): +def test_attackgraph_transitive_advanced() -> None: # TODO: Improve this test to actually use more complex transitive # relationships. Right now it is just the asset and any direct # associations it may have. - test_lang_graph = LanguageGraph(MalCompiler().compile( - 'tests/testdata/transitive_advanced.mal')) + test_lang_graph = LanguageGraph( + MalCompiler().compile('tests/testdata/transitive_advanced.mal') + ) test_lang_graph.save_to_file('tmp/trans_adv_lang_graph.yml') lang_classes_factory = LanguageClassesFactory(test_lang_graph) test_model = Model('Test Model', lang_classes_factory) - asset1 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 1') - asset2 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 2') - asset3 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 3') - asset4 = lang_classes_factory.get_asset_class('TestAsset')( - name = 'TestAsset 4') + asset1 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 1') + asset2 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 2') + asset3 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 3') + asset4 = lang_classes_factory.get_asset_class('TestAsset')(name='TestAsset 4') test_model.add_asset(asset1) test_model.add_asset(asset2) test_model.add_asset(asset3) test_model.add_asset(asset4) - assocA = create_association(test_model, - left_assets = [asset1], - right_assets = [asset2, asset3], - assoc_type = 'TransitiveTestAssocA', - left_fieldname = 'fieldA1', - right_fieldname = 'fieldA2') + assocA = create_association( + test_model, + left_assets=[asset1], + right_assets=[asset2, asset3], + assoc_type='TransitiveTestAssocA', + left_fieldname='fieldA1', + right_fieldname='fieldA2', + ) test_model.add_association(assocA) - assocB = create_association(test_model, - left_assets = [asset1], - right_assets = [asset3, asset4], - assoc_type = 'TransitiveTestAssocB', - left_fieldname = 'fieldB1', - right_fieldname = 'fieldB2') + assocB = create_association( + test_model, + left_assets=[asset1], + right_assets=[asset3, asset4], + assoc_type='TransitiveTestAssocB', + left_fieldname='fieldB1', + right_fieldname='fieldB2', + ) test_model.add_association(assocB) - test_attack_graph = AttackGraph( - lang_graph=test_lang_graph, - model=test_model - ) + test_attack_graph = AttackGraph(lang_graph=test_lang_graph, model=test_model) - asset1_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 1:test_step') - asset2_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 2:test_step') - asset3_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 3:test_step') - asset4_test_step = test_attack_graph.get_node_by_full_name( - 'TestAsset 4:test_step') + asset1_test_step = test_attack_graph.get_node_by_full_name('TestAsset 1:test_step') + asset2_test_step = test_attack_graph.get_node_by_full_name('TestAsset 2:test_step') + asset3_test_step = test_attack_graph.get_node_by_full_name('TestAsset 3:test_step') + asset4_test_step = test_attack_graph.get_node_by_full_name('TestAsset 4:test_step') assert asset1_test_step in asset1_test_step.children assert asset2_test_step not in asset1_test_step.children diff --git a/tests/attackgraph/test_node.py b/tests/attackgraph/test_node.py index c6255746..a088444b 100644 --- a/tests/attackgraph/test_node.py +++ b/tests/attackgraph/test_node.py @@ -1,55 +1,56 @@ -"""Unit tests for AttackGraphNode functionality""" +"""Unit tests for AttackGraphNode functionality.""" -from maltoolbox.attackgraph.node import AttackGraphNode from maltoolbox.attackgraph.attacker import Attacker from maltoolbox.attackgraph.attackgraph import AttackGraph +from maltoolbox.attackgraph.node import AttackGraphNode from maltoolbox.language import LanguageGraph -def test_attackgraphnode(dummy_lang_graph: LanguageGraph): - r"""Create a graph from nodes - node1 - / \ +def test_attackgraphnode(dummy_lang_graph: LanguageGraph) -> None: + r"""Create a graph from nodes. + + node1 + / \ node2 node3 - / \ / \ + / \ / \ node4 node5 node6 """ - - dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].\ - attack_steps['DummyAttackStep'] + dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].attack_steps[ + 'DummyAttackStep' + ] # Create a graph of nodes according to above diagram node1 = AttackGraphNode( - type = "or", - name = "node1", - lang_graph_attack_step = dummy_attack_step, + type='or', + name='node1', + lang_graph_attack_step=dummy_attack_step, ) node2 = AttackGraphNode( - type = "defense", - name = "node2", - lang_graph_attack_step = dummy_attack_step, - defense_status=1.0 + type='defense', + name='node2', + lang_graph_attack_step=dummy_attack_step, + defense_status=1.0, ) node3 = AttackGraphNode( - type = "defense", - name = "node3", - lang_graph_attack_step = dummy_attack_step, - defense_status=0.0 + type='defense', + name='node3', + lang_graph_attack_step=dummy_attack_step, + defense_status=0.0, ) node4 = AttackGraphNode( - type = "or", - name = "node4", - lang_graph_attack_step = dummy_attack_step, + type='or', + name='node4', + lang_graph_attack_step=dummy_attack_step, ) node5 = AttackGraphNode( - type = "and", - name = "node5", - lang_graph_attack_step = dummy_attack_step, + type='and', + name='node5', + lang_graph_attack_step=dummy_attack_step, ) node6 = AttackGraphNode( - type = "or", - name = "node6", - lang_graph_attack_step = dummy_attack_step, + type='or', + name='node6', + lang_graph_attack_step=dummy_attack_step, ) node1.children = [node2, node3] @@ -63,9 +64,7 @@ def test_attackgraphnode(dummy_lang_graph: LanguageGraph): # Make sure compromised node has attacker added to it attacker = Attacker( - name = "Test Attacker", - entry_points = [node1], - reached_attack_steps = [] + name='Test Attacker', entry_points=[node1], reached_attack_steps=[] ) attack_graph = AttackGraph(dummy_lang_graph) diff --git a/tests/attackgraph/test_query.py b/tests/attackgraph/test_query.py index f3541aec..4d386a8d 100644 --- a/tests/attackgraph/test_query.py +++ b/tests/attackgraph/test_query.py @@ -1,29 +1,26 @@ -"""Unit tests for AttackGraph functionality""" +"""Unit tests for AttackGraph functionality.""" -from maltoolbox.attackgraph import AttackGraphNode, Attacker, AttackGraph -from maltoolbox.language import LanguageGraph +from maltoolbox.attackgraph import Attacker, AttackGraph, AttackGraphNode from maltoolbox.attackgraph.query import ( is_node_traversable_by_attacker, ) +from maltoolbox.language import LanguageGraph -def test_query_is_node_traversable_by_attacker(dummy_lang_graph: LanguageGraph): - """Make sure it returns True or False when expected""" - dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].\ - attack_steps['DummyAttackStep'] +def test_query_is_node_traversable_by_attacker(dummy_lang_graph: LanguageGraph) -> None: + """Make sure it returns True or False when expected.""" + dummy_attack_step = dummy_lang_graph.assets['DummyAsset'].attack_steps[ + 'DummyAttackStep' + ] # An attacker with no meaningful data - attacker = Attacker( - name = "Test Attacker", - entry_points = [], - reached_attack_steps = [] - ) + attacker = Attacker(name='Test Attacker', entry_points=[], reached_attack_steps=[]) # Node1 should be traversable since node type is OR node1 = AttackGraphNode( - type = "or", - name = "node1", - lang_graph_attack_step = dummy_attack_step, + type='or', + name='node1', + lang_graph_attack_step=dummy_attack_step, ) attack_graph = AttackGraph(dummy_lang_graph) @@ -34,9 +31,9 @@ def test_query_is_node_traversable_by_attacker(dummy_lang_graph: LanguageGraph): # Node2 should be traversable since node has no parents node2 = AttackGraphNode( - type = "and", - name = "node2", - lang_graph_attack_step = dummy_attack_step, + type='and', + name='node2', + lang_graph_attack_step=dummy_attack_step, ) attack_graph.add_node(node2) traversable = is_node_traversable_by_attacker(node2, attacker) @@ -45,14 +42,14 @@ def test_query_is_node_traversable_by_attacker(dummy_lang_graph: LanguageGraph): # Node 4 should not be traversable since node has type AND # and it has two parents that are not compromised by attacker node3 = AttackGraphNode( - type = "and", - name = "node3", - lang_graph_attack_step = dummy_attack_step, + type='and', + name='node3', + lang_graph_attack_step=dummy_attack_step, ) node4 = AttackGraphNode( - type = "and", - name = "node4", - lang_graph_attack_step = dummy_attack_step, + type='and', + name='node4', + lang_graph_attack_step=dummy_attack_step, ) node4.parents = [node2, node3] node2.children = [node4] diff --git a/tests/conftest.py b/tests/conftest.py index 9087bd23..bc023a53 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,45 +1,53 @@ -"""Fixtures and helpers used in several test modules""" +"""Fixtures and helpers used in several test modules.""" + import os + import pytest -from maltoolbox.language import (LanguageGraph, LanguageGraphAsset, - LanguageGraphAttackStep, LanguageClassesFactory) +from maltoolbox.language import ( + LanguageClassesFactory, + LanguageGraph, + LanguageGraphAsset, + LanguageGraphAttackStep, +) from maltoolbox.model import Model +# Helpers -## Helpers def path_testdata(filename): - """Returns the absolute path of a test data file (in ./testdata) + """Returns the absolute path of a test data file (in ./testdata). Arguments: filename - filename to append to path of ./testdata + """ current_dir = os.path.dirname(os.path.realpath(__file__)) - return os.path.join(current_dir, f"testdata/{filename}") + return os.path.join(current_dir, f'testdata/{filename}') def empty_model(name, lang_classes_factory): - """Fixture that generates a model for tests + """Fixture that generates a model for tests. Uses coreLang specification (fixture) to create and return Model """ - # Create instance model from model json file return Model(name, lang_classes_factory) -## Fixtures (can be ingested into tests) + +# Fixtures (can be ingested into tests) + @pytest.fixture def corelang_lang_graph(): - """Fixture that returns the coreLang language specification as dict""" - mar_file_path = path_testdata("org.mal-lang.coreLang-1.0.0.mar") + """Fixture that returns the coreLang language specification as dict.""" + mar_file_path = path_testdata('org.mal-lang.coreLang-1.0.0.mar') return LanguageGraph.from_mar_archive(mar_file_path) @pytest.fixture def model(corelang_lang_graph): - """Fixture that generates a model for tests + """Fixture that generates a model for tests. Uses coreLang specification (fixture) to create and return a Model object with no assets or associations @@ -53,19 +61,14 @@ def model(corelang_lang_graph): @pytest.fixture def dummy_lang_graph(corelang_lang_graph): """Fixture that generates a dummy LanguageGraph with a dummy - LanguageGraphAsset and LanguageGraphAttackStep + LanguageGraphAsset and LanguageGraphAttackStep. """ lang_graph = LanguageGraph() - dummy_asset = LanguageGraphAsset( - name = 'DummyAsset' - ) + dummy_asset = LanguageGraphAsset(name='DummyAsset') lang_graph.assets['DummyAsset'] = dummy_asset dummy_attack_step_node = LanguageGraphAttackStep( - name = 'DummyAttackStep', - type = 'or', - asset = dummy_asset + name='DummyAttackStep', type='or', asset=dummy_asset ) dummy_asset.attack_steps['DummyAttackStep'] = dummy_attack_step_node - return lang_graph diff --git a/tests/language/test_classes_factory.py b/tests/language/test_classes_factory.py index 70623c9c..0351267e 100644 --- a/tests/language/test_classes_factory.py +++ b/tests/language/test_classes_factory.py @@ -1,20 +1,22 @@ -"""Tests for the LanguageGraph""" +"""Tests for the LanguageGraph.""" -import pytest -from conftest import path_testdata +from maltoolbox.language import LanguageClassesFactory, LanguageGraph -from maltoolbox.language import LanguageGraph, LanguageClassesFactory -def test_corelang_classes_factory(corelang_lang_graph: LanguageGraph): - """ Test to see if the LanguageClassesFactory is properly generated based +def test_corelang_classes_factory(corelang_lang_graph: LanguageGraph) -> None: + """Test to see if the LanguageClassesFactory is properly generated based on coreLang Language Graph. """ # Init LanguageClassesFactory lang_classes_factory = LanguageClassesFactory(corelang_lang_graph) assert hasattr(lang_classes_factory.ns, 'Asset_Application') - assert hasattr(lang_classes_factory.ns, 'Association_ApplicationVulnerability_vulnerabilities_application') + assert hasattr( + lang_classes_factory.ns, + 'Association_ApplicationVulnerability_vulnerabilities_application', + ) -def test_create_asset(corelang_lang_graph: LanguageGraph): + +def test_create_asset(corelang_lang_graph: LanguageGraph) -> None: # Init LanguageClassesFactory - lang_classes_factory = LanguageClassesFactory(corelang_lang_graph) + LanguageClassesFactory(corelang_lang_graph) diff --git a/tests/language/test_languagegraph.py b/tests/language/test_languagegraph.py index 6604c119..5809046c 100644 --- a/tests/language/test_languagegraph.py +++ b/tests/language/test_languagegraph.py @@ -1,47 +1,46 @@ -"""Tests for the LanguageGraph""" +"""Tests for the LanguageGraph.""" -import pytest from conftest import path_testdata from maltoolbox.language import LanguageGraph - from maltoolbox.language.compiler import MalCompiler -from maltoolbox.language import LanguageGraph -def test_languagegraph_save_load(corelang_lang_graph: LanguageGraph): +def test_languagegraph_save_load(corelang_lang_graph: LanguageGraph) -> None: """Test to see if saving and loading a language graph to a file produces the same language graph. We have to use the json format to save and load - because YAML reorders the keys in alphabetical order.""" - graph_path = "/tmp/langgraph.json" + because YAML reorders the keys in alphabetical order. + """ + graph_path = '/tmp/langgraph.json' corelang_lang_graph.save_to_file(graph_path) new_lang_graph = LanguageGraph.load_from_file(graph_path) assert new_lang_graph._to_dict() == corelang_lang_graph._to_dict() + # TODO: Replace this with a dedicated test that just checks for union for # assets with the same super asset -def test_corelang_with_union_different_assets_same_super_asset(): +def test_corelang_with_union_different_assets_same_super_asset() -> None: """Uses modified coreLang language specification. An attackstep in IAMObject will contain a union between Identity and Group, which should be allowed, since they share the same super asset. """ - - mar_file_path = path_testdata("corelang-union-common-ancestor.mar") + mar_file_path = path_testdata('corelang-union-common-ancestor.mar') # Make sure that it can generate LanguageGraph.from_mar_archive(mar_file_path) -def test_interleaved_vars(): + +def test_interleaved_vars() -> None: """Check to see if two interleaved variables(variables that contain variables from each other, A2 contains B1 and B2 contains A1) were resolved correct. """ - - test_lang_graph = LanguageGraph(MalCompiler().compile( - 'tests/testdata/interleaved_vars.mal')) + test_lang_graph = LanguageGraph( + MalCompiler().compile('tests/testdata/interleaved_vars.mal') + ) assert 'AssetA' in test_lang_graph.assets assert 'AssetB' in test_lang_graph.assets @@ -60,12 +59,15 @@ def test_interleaved_vars(): assert varB2[0] == assetB assert varB2[1].right_link.fieldname == 'fieldB' -def test_inherited_vars(): + +def test_inherited_vars() -> None: LanguageGraph(MalCompiler().compile('tests/testdata/inherited_vars.mal')) -def test_attackstep_override(): - test_lang_graph = LanguageGraph(MalCompiler().compile( - 'tests/testdata/attackstep_override.mal')) + +def test_attackstep_override() -> None: + test_lang_graph = LanguageGraph( + MalCompiler().compile('tests/testdata/attackstep_override.mal') + ) assert 'EmptyParent' in test_lang_graph.assets assert 'Child1' in test_lang_graph.assets @@ -144,6 +146,7 @@ def test_attackstep_override(): assert fc_target3.full_name in fc_attackstep.children assert fc_target4.full_name in fc_attackstep.children + # TODO: Re-enable this test once the compiler and language are compatible with # one another. # def test_mallib_mal(): diff --git a/tests/test_model.py b/tests/test_model.py index 8a4caa6c..8f34df39 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,53 +1,51 @@ -"""Unit tests for maltoolbox.model""" +"""Unit tests for maltoolbox.model.""" import pytest +from conftest import path_testdata from python_jsonschema_objects.validators import ValidationError -from conftest import empty_model, path_testdata -from maltoolbox.model import Model, AttackerAttachment -from maltoolbox.exceptions import ModelAssociationException, DuplicateModelAssociationError +from maltoolbox.exceptions import ( + DuplicateModelAssociationError, + ModelAssociationException, +) +from maltoolbox.model import AttackerAttachment, Model -### Helper functions +# Helper functions -APP_EXEC_ASSOC_NAME = "AppExecution" -DATA_CONTAIN_ASSOC_NAME = "DataContainment" +APP_EXEC_ASSOC_NAME = 'AppExecution' +DATA_CONTAIN_ASSOC_NAME = 'DataContainment' -def create_application_asset(model, name): - """Helper function to create an asset of coreLang type Application""" - return model.lang_classes_factory.get_asset_class('Application')( - name = name) +def create_application_asset(model, name): + """Helper function to create an asset of coreLang type Application.""" + return model.lang_classes_factory.get_asset_class('Application')(name=name) def create_data_asset(model, name): - """Helper function to create an asset of coreLang type Data""" - - return model.lang_classes_factory.get_asset_class('Data')(name = name) + """Helper function to create an asset of coreLang type Data.""" + return model.lang_classes_factory.get_asset_class('Data')(name=name) def create_association( - model, - left_assets, - right_assets, - assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", - right_fieldname="appExecutedApps", - ): + model, + left_assets, + right_assets, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', +): """Helper function to create an association dict with - given parameters, useful in tests""" - + given parameters, useful in tests. + """ # Simulate receiving the association from a json file association_dict = { - assoc_type: { - left_fieldname: left_assets, - right_fieldname: right_assets - } + assoc_type: {left_fieldname: left_assets, right_fieldname: right_assets} } # Create the association using the lang_classes_factory - association = model.lang_classes_factory.\ - get_association_class_by_fieldnames( - assoc_type, left_fieldname, right_fieldname)() + association = model.lang_classes_factory.get_association_class_by_fieldnames( + assoc_type, left_fieldname, right_fieldname + )() # Add the assets for field, assets in association_dict[assoc_type].items(): @@ -55,13 +53,15 @@ def create_association( return association -### Tests -def test_attacker_attachment_add_entry_point(model: Model): +# Tests + + +def test_attacker_attachment_add_entry_point(model: Model) -> None: """""" - asset1 = create_application_asset(model, "Asset1") - asset2 = create_application_asset(model, "Asset2") + asset1 = create_application_asset(model, 'Asset1') + asset2 = create_application_asset(model, 'Asset2') model.add_asset(asset1) model.add_asset(asset2) @@ -91,11 +91,11 @@ def test_attacker_attachment_add_entry_point(model: Model): assert attacker1.entry_points[1][1] == ['access'] -def test_attacker_attachment_remove_entry_point(model: Model): +def test_attacker_attachment_remove_entry_point(model: Model) -> None: """""" - asset1 = create_application_asset(model, "Asset1") - asset2 = create_application_asset(model, "Asset2") + asset1 = create_application_asset(model, 'Asset1') + asset2 = create_application_asset(model, 'Asset2') model.add_asset(asset1) model.add_asset(asset2) @@ -145,11 +145,11 @@ def test_attacker_attachment_remove_entry_point(model: Model): assert len(attacker1.entry_points) == 0 -def test_attacker_attachment_remove_asset(model: Model): +def test_attacker_attachment_remove_asset(model: Model) -> None: """""" - asset1 = create_application_asset(model, "Asset1") - asset2 = create_application_asset(model, "Asset2") + asset1 = create_application_asset(model, 'Asset1') + asset2 = create_application_asset(model, 'Asset2') model.add_asset(asset1) model.add_asset(asset2) @@ -201,10 +201,10 @@ def test_attacker_attachment_remove_asset(model: Model): assert attacker2.entry_points[0][1] == ['read'] -def test_add_remove_attacker(model: Model): +def test_add_remove_attacker(model: Model) -> None: """""" - asset1 = create_application_asset(model, "Asset1") + asset1 = create_application_asset(model, 'Asset1') model.add_asset(asset1) attacker1 = AttackerAttachment() @@ -222,9 +222,9 @@ def test_add_remove_attacker(model: Model): model.remove_attacker(attacker2) assert len(model.attackers) == 0 -def test_model_add_asset(model: Model): - """Make sure assets are added correctly""" +def test_model_add_asset(model: Model) -> None: + """Make sure assets are added correctly.""" assets_before = list(model.assets) # Create an application asset @@ -237,10 +237,10 @@ def test_model_add_asset(model: Model): assert p1 in model.assets -def test_model_add_asset_with_id_set(model): +def test_model_add_asset_with_id_set(model) -> None: """Make sure assets are added and next_id correctly updated - when id is set explicitly in method call""" - + when id is set explicitly in method call. + """ p1 = create_application_asset(model, 'Program 1') p1_id = model.next_id + 10 model.add_asset(p1, asset_id=p1_id) @@ -259,10 +259,9 @@ def test_model_add_asset_with_id_set(model): assert p2 not in model.assets -def test_model_add_asset_duplicate_name(model: Model): - """Add several assets with the same name to the model""" - - asset_name = "MyProgram" +def test_model_add_asset_duplicate_name(model: Model) -> None: + """Add several assets with the same name to the model.""" + asset_name = 'MyProgram' # Add a new asset p1 = create_application_asset(model, asset_name) @@ -275,7 +274,7 @@ def test_model_add_asset_duplicate_name(model: Model): model.add_asset(p2) assert len(model.assets) == 2 # Is this expected - shouldn't p2 have same name as p1? - assert model.assets[1].name == f"{asset_name}:{p2.id}" + assert model.assets[1].name == f'{asset_name}:{p2.id}' # Add asset again while not allowing duplicates, expect ValueError with pytest.raises(ValueError): @@ -284,9 +283,8 @@ def test_model_add_asset_duplicate_name(model: Model): assert len(model.assets) == 2 -def test_model_remove_asset(model: Model): - """Remove assets from a model""" - +def test_model_remove_asset(model: Model) -> None: + """Remove assets from a model.""" # Add two program assets to the model p1 = create_application_asset(model, 'Program 1') p2 = create_application_asset(model, 'Program 2') @@ -302,9 +300,8 @@ def test_model_remove_asset(model: Model): assert len(model.assets) == num_assets_before - 1 -def test_model_remove_nonexisting_asset(model: Model): - """Removing a non existing asset leads to lookup error""" - +def test_model_remove_nonexisting_asset(model: Model) -> None: + """Removing a non existing asset leads to lookup error.""" # Create an asset but don't add it to the model before removing it p1 = create_application_asset(model, 'Program 1') p1.id = 1 # Needs id to avoid crash in log statement @@ -312,9 +309,8 @@ def test_model_remove_nonexisting_asset(model: Model): model.remove_asset(p1) -def test_model_add_association(model: Model): - """Make sure associations work as intended""" - +def test_model_add_association(model: Model) -> None: + """Make sure associations work as intended.""" # Create two assets p1 = create_application_asset(model, 'Program 1') p1_id = model.next_id @@ -327,9 +323,12 @@ def test_model_add_association(model: Model): # Create an association between p1 and p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) associations_before = list(model.associations) @@ -345,9 +344,8 @@ def test_model_add_association(model: Model): assert association in p2.associations -def test_model_add_appexecution_association_two_assets(model: Model): - """coreLang specifies that AppExecution only can have one 'left' asset""" - +def test_model_add_appexecution_association_two_assets(model: Model) -> None: + """CoreLang specifies that AppExecution only can have one 'left' asset.""" # Add program assets p1 = create_application_asset(model, 'Program 1') p1_id = model.next_id @@ -361,15 +359,17 @@ def test_model_add_appexecution_association_two_assets(model: Model): # will raise error because two assets (p1,p2) # are not allowed in the left field for AppExecution create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1, p2], right_assets=[p1] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1, p2], + right_assets=[p1], ) -def test_model_add_association_duplicate(model: Model): - """Make sure same association is not added twice""" - +def test_model_add_association_duplicate(model: Model) -> None: + """Make sure same association is not added twice.""" # Create three data assets d1 = create_data_asset(model, 'Data 1') d1_id = model.next_id @@ -385,21 +385,30 @@ def test_model_add_association_duplicate(model: Model): # Create an association between (d1, d2) and d3 association1 = create_association( - model, assoc_type=DATA_CONTAIN_ASSOC_NAME, - left_fieldname="containingData", right_fieldname="containedData", - left_assets=[d1, d2], right_assets=[d3] + model, + assoc_type=DATA_CONTAIN_ASSOC_NAME, + left_fieldname='containingData', + right_fieldname='containedData', + left_assets=[d1, d2], + right_assets=[d3], ) # Create an identical association, but from just d2 association2 = create_association( - model, assoc_type=DATA_CONTAIN_ASSOC_NAME, - left_fieldname="containingData", right_fieldname="containedData", - left_assets=[d2], right_assets=[d3] + model, + assoc_type=DATA_CONTAIN_ASSOC_NAME, + left_fieldname='containingData', + right_fieldname='containedData', + left_assets=[d2], + right_assets=[d3], ) # Create association with duplicate assets in both fields association3 = create_association( - model, assoc_type=DATA_CONTAIN_ASSOC_NAME, - left_fieldname="containingData", right_fieldname="containedData", - left_assets=[d2, d2], right_assets=[d3, d3] + model, + assoc_type=DATA_CONTAIN_ASSOC_NAME, + left_fieldname='containingData', + right_fieldname='containedData', + left_assets=[d2, d2], + right_assets=[d3, d3], ) # Add the first association to the model - no problem @@ -412,7 +421,7 @@ def test_model_add_association_duplicate(model: Model): model.add_association(association1) assert str(e.value) == ( - 'Identical association %s already exists' % DATA_CONTAIN_ASSOC_NAME + f'Identical association {DATA_CONTAIN_ASSOC_NAME} already exists' ) # Add the second (almost identical) association. @@ -422,8 +431,8 @@ def test_model_add_association_duplicate(model: Model): model.add_association(association2) assert str(e.value) == ( - 'Association type %s already exists' - ' between Data 2 and Data 3' % DATA_CONTAIN_ASSOC_NAME + f'Association type {DATA_CONTAIN_ASSOC_NAME} already exists' + ' between Data 2 and Data 3' ) # Add the third association, should fail because of duplicate @@ -433,19 +442,22 @@ def test_model_add_association_duplicate(model: Model): assert len(model.associations) == 1 -def test_model_remove_association(model: Model): - """Make sure association can be removed""" +def test_model_remove_association(model: Model) -> None: + """Make sure association can be removed.""" # Create and add 2 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') model.add_asset(p1) model.add_asset(p2) association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) model.add_association(association) @@ -465,13 +477,16 @@ def test_model_remove_association(model: Model): assert association not in p2.associations -def test_model_remove_association_nonexisting(model: Model): - """Make sure non existing association can not be removed""" +def test_model_remove_association_nonexisting(model: Model) -> None: + """Make sure non existing association can not be removed.""" # Create the association but don't add it association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[], right_assets=[] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[], + right_assets=[], ) # No associations exists @@ -482,21 +497,24 @@ def test_model_remove_association_nonexisting(model: Model): model.remove_association(association) -def test_model_remove_asset_from_association(model: Model): +def test_model_remove_asset_from_association(model: Model) -> None: """Make sure we can remove asset from association and that - associations with no assets on any 'side' is removed""" - + associations with no assets on any 'side' is removed. + """ # Create and add 2 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') model.add_asset(p1) model.add_asset(p2) # Create and add association from p1 to p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) model.add_association(association) @@ -508,27 +526,28 @@ def test_model_remove_asset_from_association(model: Model): assert association not in model.associations -def test_model_remove_asset_from_association_nonexisting_asset( - model: Model - ): +def test_model_remove_asset_from_association_nonexisting_asset(model: Model) -> None: """Make sure error is thrown if deleting non existing asset - from association""" - + from association. + """ # Create 4 applications and add 3 of them to model - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") - p3 = create_application_asset(model, "Program 3") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') + p3 = create_application_asset(model, 'Program 3') p4 = create_application_asset(model, 'Program 4') model.add_asset(p1) model.add_asset(p2) model.add_asset(p3) - p4.id = 1 # ID is required, otherwise crash in log statement + p4.id = 1 # ID is required, otherwise crash in log statement # Create an association between p1 and p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) model.add_association(association) @@ -543,31 +562,33 @@ def test_model_remove_asset_from_association_nonexisting_asset( def test_model_remove_asset_from_association_nonexisting_association( - model: Model - ): + model: Model, +) -> None: """Make sure error is thrown if deleting non existing asset - from association""" - + from association. + """ # Create and add 2 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') model.add_asset(p1) model.add_asset(p2) # Create (but don't add!) an association between p1 and p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) # We are removing an association that was never created -> LookupError with pytest.raises(LookupError): model.remove_asset_from_association(p1, association) -def test_model_add_attacker(model: Model): - """Test functionality to add an attacker to the model""" - +def test_model_add_attacker(model: Model) -> None: + """Test functionality to add an attacker to the model.""" # Add attacker 1 attacker1 = AttackerAttachment() attacker1.entry_points = [] @@ -586,22 +607,19 @@ def test_model_add_attacker(model: Model): asset_id = 1 attack_steps = ['attemptCredentialsReuse'] - attacker2.entry_points = [ - (asset_id, attack_steps) - ] + attacker2.entry_points = [(asset_id, attack_steps)] model.add_attacker(attacker2, attacker_id=attacker2_id) - assert attacker2.name == f'Attacker:{attacker2_id}' -def test_model_get_asset_by_id(model: Model): +def test_model_get_asset_by_id(model: Model) -> None: """Make sure correct asset is returned or None - if no asset with that ID exists""" - + if no asset with that ID exists. + """ # Create and add 2 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') model.add_asset(p1) model.add_asset(p2) @@ -611,26 +629,26 @@ def test_model_get_asset_by_id(model: Model): assert model.get_asset_by_id(1337) is None -def test_model_get_asset_by_name(model: Model): +def test_model_get_asset_by_name(model: Model) -> None: """Make sure correct asset is returned or None - if no asset with that name exists""" - + if no asset with that name exists. + """ # Create and add 2 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') model.add_asset(p1) model.add_asset(p2) # Correct assets removed and None if asset with that name not exists assert model.get_asset_by_name(p1.name) == p1 assert model.get_asset_by_name(p2.name) == p2 - assert model.get_asset_by_name("Program 3") is None + assert model.get_asset_by_name('Program 3') is None -def test_model_get_attacker_by_id(model: Model): +def test_model_get_attacker_by_id(model: Model) -> None: """Make sure correct attacker is returned of None - if no attacker with that ID exists""" - + if no attacker with that ID exists. + """ # Add attacker 1 attacker1 = AttackerAttachment() attacker1.entry_points = [] @@ -641,45 +659,45 @@ def test_model_get_attacker_by_id(model: Model): assert model.get_attacker_by_id(1337) is None -def test_model_get_associated_assets_by_fieldname(model: Model): +def test_model_get_associated_assets_by_fieldname(model: Model) -> None: """Make sure assets associated to the asset through the given - field_name are returned""" - + field_name are returned. + """ # Create and add 2 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') model.add_asset(p1) model.add_asset(p2) # Create and add an association between p1 and p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) model.add_association(association) # Since p2 is in an association with p1 through 'appExecutedApps' # p2 should be returned as an associated asset - ret = model.get_associated_assets_by_field_name( - p1, "appExecutedApps") + ret = model.get_associated_assets_by_field_name(p1, 'appExecutedApps') assert p2 in ret # Other fieldname from p2 to p1 - ret = model.get_associated_assets_by_field_name( - p2, "hostApp") + ret = model.get_associated_assets_by_field_name(p2, 'hostApp') assert p1 in ret # Non existing field name should give no assets - ret = model.get_associated_assets_by_field_name( - p1, "bogusFieldName") + ret = model.get_associated_assets_by_field_name(p1, 'bogusFieldName') assert ret == [] -def test_model_asset_to_dict(model: Model): - """Make sure assets are converted to dictionaries correctly""" +def test_model_asset_to_dict(model: Model) -> None: + """Make sure assets are converted to dictionaries correctly.""" # Create and add asset - p1 = create_application_asset(model, "Program 1") + p1 = create_application_asset(model, 'Program 1') model.add_asset(p1) # Tuple is returned @@ -696,12 +714,13 @@ def test_model_asset_to_dict(model: Model): assert p1_dict.get('type') == 'Application' # Default values should not be saved - assert p1_dict.get('defenses') == None + assert p1_dict.get('defenses') is None -def test_model_asset_with_nondefault_defense_to_dict(model: Model): - """Make sure assets are converted to dictionaries correctly""" + +def test_model_asset_with_nondefault_defense_to_dict(model: Model) -> None: + """Make sure assets are converted to dictionaries correctly.""" # Create and add asset - p1 = create_application_asset(model, "Program 1") + p1 = create_application_asset(model, 'Program 1') p1.notPresent = 1.0 model.add_asset(p1) @@ -719,49 +738,47 @@ def test_model_asset_with_nondefault_defense_to_dict(model: Model): assert p1_dict.get('type') == 'Application' # Default values for 'Application' defenses in coreLang - assert p1_dict.get('defenses') == { - 'notPresent': 1.0 - } + assert p1_dict.get('defenses') == {'notPresent': 1.0} -def test_model_association_to_dict(model: Model): - """Make sure associations are converted to dictionaries correctly""" +def test_model_association_to_dict(model: Model) -> None: + """Make sure associations are converted to dictionaries correctly.""" # Create and add 2 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') model.add_asset(p1) model.add_asset(p2) # Create and add an association between p1 and p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) model.add_association(association) association_dict = model.association_to_dict(association) - association_type = list(association_dict.keys())[0] + association_type = next(iter(association_dict.keys())) assert association_type == APP_EXEC_ASSOC_NAME - assert association_dict[association_type ] == { + assert association_dict[association_type] == { 'hostApp': {p1.id: str(p1.name)}, - 'appExecutedApps': {p2.id: str(p2.name)} + 'appExecutedApps': {p2.id: str(p2.name)}, } -def test_model_attacker_to_dict(model: Model): - """Make sure attackers get correct format and values""" - +def test_model_attacker_to_dict(model: Model) -> None: + """Make sure attackers get correct format and values.""" # Create and add an asset - p1 = create_application_asset(model, "Program 1") + p1 = create_application_asset(model, 'Program 1') model.add_asset(p1) # Add attacker 1 attacker = AttackerAttachment() - attack_steps = ["attemptCredentialsReuse"] - attacker.entry_points = [ - (p1, attack_steps) - ] + attack_steps = ['attemptCredentialsReuse'] + attacker.entry_points = [(p1, attack_steps)] model.add_attacker(attacker) # Convert the attacker to a dictionary and make sure @@ -776,7 +793,8 @@ def test_model_attacker_to_dict(model: Model): # attacker should be attached to p1, therefore p1s # id should be a key in the entry_points_dict - assert p1.name is not None and entry_points_dict + assert p1.name is not None + assert entry_points_dict assert p1.name in entry_points_dict # The given steps should be inside the entry_point of @@ -784,31 +802,31 @@ def test_model_attacker_to_dict(model: Model): assert entry_points_dict[p1.name]['attack_steps'] == attack_steps -def test_serialize(model: Model): - """Put all to_dict methods together and see that they work""" - +def test_serialize(model: Model) -> None: + """Put all to_dict methods together and see that they work.""" # Create and add 3 applications - p1 = create_application_asset(model, "Program 1") - p2 = create_application_asset(model, "Program 2") - p3 = create_application_asset(model, "Program 3") + p1 = create_application_asset(model, 'Program 1') + p2 = create_application_asset(model, 'Program 2') + p3 = create_application_asset(model, 'Program 3') model.add_asset(p1) model.add_asset(p2) model.add_asset(p3) # Create and add an association between p1 and p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) model.add_association(association) # Add attacker attacker = AttackerAttachment() - attack_steps = ["attemptCredentialsReuse"] - attacker.entry_points = [ - (p1, attack_steps) - ] + attack_steps = ['attemptCredentialsReuse'] + attacker.entry_points = [(p1, attack_steps)] model.add_attacker(attacker) model_dict = model._to_dict() @@ -816,89 +834,84 @@ def test_serialize(model: Model): # to_dict will create map from asset id to asset dict # (dict is second value of tuple returned from asset_to_dict) for asset in [p1, p2, p3]: - assert model_dict['assets'][asset.id] == \ - model.asset_to_dict(asset)[1] + assert model_dict['assets'][asset.id] == model.asset_to_dict(asset)[1] # associations are added as they are created by association_to_dict - assert model_dict['associations'] == \ - [model.association_to_dict(association)] + assert model_dict['associations'] == [model.association_to_dict(association)] # attackers are added similar to assets (id maps to attacker dict) - assert model_dict['attackers'][attacker.id] == \ - model.attacker_to_dict(attacker)[1] + assert model_dict['attackers'][attacker.id] == model.attacker_to_dict(attacker)[1] # Meta data should also be added assert model_dict['metadata']['name'] == model.name - assert model_dict['metadata']['langVersion'] == \ - model.lang_classes_factory.lang_graph.metadata['version'] - assert model_dict['metadata']['langID'] == \ - model.lang_classes_factory.lang_graph.metadata['id'] + assert ( + model_dict['metadata']['langVersion'] + == model.lang_classes_factory.lang_graph.metadata['version'] + ) + assert ( + model_dict['metadata']['langID'] + == model.lang_classes_factory.lang_graph.metadata['id'] + ) -def test_model_save_and_load_model_from_scratch(model: Model): +def test_model_save_and_load_model_from_scratch(model: Model) -> None: """Create a model, save it to file, load it from file and compare them - Note: will create file in /tmp + Note: will create file in /tmp. """ - # Create and add 3 applications - p1 = create_application_asset(model, "Program 1") - p1.extras = {"testing": "testing"} - p2 = create_application_asset(model, "Program 2") - p3 = create_application_asset(model, "Program 3") + p1 = create_application_asset(model, 'Program 1') + p1.extras = {'testing': 'testing'} + p2 = create_application_asset(model, 'Program 2') + p3 = create_application_asset(model, 'Program 3') model.add_asset(p1) model.add_asset(p2) model.add_asset(p3) # Create and add an association between p1 and p2 association = create_association( - model, assoc_type=APP_EXEC_ASSOC_NAME, - left_fieldname="hostApp", right_fieldname="appExecutedApps", - left_assets=[p1], right_assets=[p2] + model, + assoc_type=APP_EXEC_ASSOC_NAME, + left_fieldname='hostApp', + right_fieldname='appExecutedApps', + left_assets=[p1], + right_assets=[p2], ) model.add_association(association) # Add attacker attacker = AttackerAttachment() - attack_steps = ["attemptCredentialsReuse"] - attacker.entry_points = [ - (p1, attack_steps) - ] + attack_steps = ['attemptCredentialsReuse'] + attacker.entry_points = [(p1, attack_steps)] model.add_attacker(attacker) - for model_path in ("/tmp/test.json", "/tmp/test.yaml", "/tmp/test.yml"): + for model_path in ('/tmp/test.json', '/tmp/test.yaml', '/tmp/test.yml'): # Mock open() function so no real files are written to filesystem model.save_to_file(model_path) # Create a new model by loading old model from file - new_model = Model.load_from_file( - model_path, - model.lang_classes_factory - ) + new_model = Model.load_from_file(model_path, model.lang_classes_factory) assert new_model._to_dict() == model._to_dict() -def test_model_save_and_load_model_example_model(model): +def test_model_save_and_load_model_example_model(model) -> None: """Load the simple_example_model.json from testdata, store it, compare - Note: will create file in /tmp""" - + Note: will create file in /tmp. + """ # Load from example file model = Model.load_from_file( - path_testdata("simple_example_model.yml"), - model.lang_classes_factory + path_testdata('simple_example_model.yml'), model.lang_classes_factory ) # Save to file model.save_to_file('/tmp/test.json') # Create new model and load from previously saved file - new_model = Model.load_from_file( - '/tmp/test.json', - model.lang_classes_factory - ) + new_model = Model.load_from_file('/tmp/test.json', model.lang_classes_factory) assert new_model._to_dict() == model._to_dict() + # TODO: Re-enable this test when the updater translator has been updated(oh, # the irony). # def test_model_load_older_version_example_model(model): diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 7a452c3a..b38acec1 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,8 +1,10 @@ -from maltoolbox.wrappers import create_attack_graph from conftest import path_testdata -def test_create_attack_graph(): - """See that the create attack graph wrapper works""" +from maltoolbox.wrappers import create_attack_graph + + +def test_create_attack_graph() -> None: + """See that the create attack graph wrapper works.""" mar = path_testdata('org.mal-lang.coreLang-1.0.0.mar') model = path_testdata('simple_example_model.yml') diff --git a/tests/translators/test_securicad_translator.py b/tests/translators/test_securicad_translator.py index 2ca3125d..5a48e43b 100644 --- a/tests/translators/test_securicad_translator.py +++ b/tests/translators/test_securicad_translator.py @@ -1,12 +1,5 @@ -"""Unit tests for AttackGraph functionality""" +"""Unit tests for AttackGraph functionality.""" -import pytest - -from conftest import path_testdata - -from maltoolbox.model import Model -from maltoolbox.language import LanguageClassesFactory -from maltoolbox.translators import securicad # TODO Re-enable this when the securicad translator has been updated. # def test_securicad_translator(corelang_lang_graph):