From 3bf759a8caa4c1c87a3c4aa0ecca89b42e1178ab Mon Sep 17 00:00:00 2001 From: RA80533 <32469082+RA80533@users.noreply.github.com> Date: Mon, 5 Jun 2023 23:36:48 +0000 Subject: [PATCH] db.py: Refactor type annotations --- spiderfoot/db.py | 118 +++++++++++++++++++++++++---------------------- 1 file changed, 63 insertions(+), 55 deletions(-) diff --git a/spiderfoot/db.py b/spiderfoot/db.py index b206659318..8a35780dd1 100644 --- a/spiderfoot/db.py +++ b/spiderfoot/db.py @@ -17,6 +17,9 @@ import sqlite3 import threading import time +import typing + +from .event import SpiderFootEvent class SpiderFootDb: @@ -28,8 +31,8 @@ class SpiderFootDb: dbhLock (_thread.RLock): thread lock on database handle """ - dbh = None - conn = None + dbh: sqlite3.Cursor + conn: sqlite3.Connection # Prevent multithread access to sqlite database dbhLock = threading.RLock() @@ -283,7 +286,7 @@ class SpiderFootDb: ['WIKIPEDIA_PAGE_EDIT', 'Wikipedia Page Edit', 0, 'DESCRIPTOR'], ] - def __init__(self, opts: dict, init: bool = False) -> None: + def __init__(self, opts: typing.Dict[str, str], init: bool = False) -> None: """Initialize database and create handle to the SQLite database file. Creates the database file if it does not exist. Creates database schema if it does not exist. @@ -427,7 +430,7 @@ def close(self) -> None: with self.dbhLock: self.dbh.close() - def vacuumDB(self) -> None: + def vacuumDB(self) -> bool: """Vacuum the database. Clears unused database file pages. Returns: @@ -445,7 +448,7 @@ def vacuumDB(self) -> None: raise IOError("SQL error encountered when vacuuming the database") from e return False - def search(self, criteria: dict, filterFp: bool = False) -> list: + def search(self, criteria: typing.Dict[str, str], filterFp: bool = False) -> typing.List[typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, typing.Optional[str], str, str, str, int, int]]: """Search database. Args: @@ -488,7 +491,7 @@ def search(self, criteria: dict, filterFp: bool = False) -> list: if len(criteria) == 1: raise ValueError("Only one search criteria provided; expected at least two") - qvars = list() + qvars: typing.List[str] = list() qry = "SELECT ROUND(c.generated) AS generated, c.data, \ s.data as 'source_data', \ c.module, c.type, c.confidence, c.visibility, c.risk, c.hash, \ @@ -528,7 +531,7 @@ def search(self, criteria: dict, filterFp: bool = False) -> list: except sqlite3.Error as e: raise IOError("SQL error encountered when fetching search results") from e - def eventTypes(self) -> list: + def eventTypes(self) -> typing.List[typing.Tuple[str, str, int, str]]: """Get event types. Returns: @@ -546,7 +549,7 @@ def eventTypes(self) -> list: except sqlite3.Error as e: raise IOError("SQL error encountered when retrieving event types") from e - def scanLogEvents(self, batch: list) -> bool: + def scanLogEvents(self, batch: typing.List[typing.Tuple[str, str, str, typing.Optional[str], float]]) -> bool: """Logs a batch of events to the database. Args: @@ -560,7 +563,7 @@ def scanLogEvents(self, batch: list) -> bool: bool: Whether the logging operation succeeded """ - inserts = [] + inserts: typing.List[typing.Tuple[str, float, str, str, str]] = [] for instanceId, classification, message, component, logTime in batch: if not isinstance(instanceId, str): @@ -592,7 +595,7 @@ def scanLogEvents(self, batch: list) -> bool: return False return True - def scanLogEvent(self, instanceId: str, classification: str, message: str, component: str = None) -> None: + def scanLogEvent(self, instanceId: str, classification: str, message: str, component: typing.Optional[str] = None) -> None: """Log an event to the database. Args: @@ -673,7 +676,7 @@ def scanInstanceCreate(self, instanceId: str, scanName: str, scanTarget: str) -> except sqlite3.Error as e: raise IOError("Unable to create scan instance in database") from e - def scanInstanceSet(self, instanceId: str, started: str = None, ended: str = None, status: str = None) -> None: + def scanInstanceSet(self, instanceId: str, started: typing.Optional[str] = None, ended: typing.Optional[str] = None, status: typing.Optional[str] = None) -> None: """Update the start time, end time or status (or all 3) of a scan instance. Args: @@ -690,7 +693,7 @@ def scanInstanceSet(self, instanceId: str, started: str = None, ended: str = Non if not isinstance(instanceId, str): raise TypeError(f"instanceId is {type(instanceId)}; expected str()") from None - qvars = list() + qvars: typing.List[str] = list() qry = "UPDATE tbl_scan_instance SET " if started is not None: @@ -716,14 +719,14 @@ def scanInstanceSet(self, instanceId: str, started: str = None, ended: str = Non except sqlite3.Error: raise IOError("Unable to set information for the scan instance.") from None - def scanInstanceGet(self, instanceId: str) -> list: + def scanInstanceGet(self, instanceId: str) -> typing.Tuple[str, str, typing.Optional[int], typing.Optional[int], typing.Optional[int], str]: """Return info about a scan instance (name, target, created, started, ended, status) Args: instanceId (str): scan instance ID Returns: - list: scan instance info + tuple: scan instance info Raises: TypeError: arg type was invalid @@ -745,7 +748,7 @@ def scanInstanceGet(self, instanceId: str) -> list: except sqlite3.Error as e: raise IOError("SQL error encountered when retrieving scan instance") from e - def scanResultSummary(self, instanceId: str, by: str = "type") -> list: + def scanResultSummary(self, instanceId: str, by: str = "type") -> typing.Union[typing.List[typing.Tuple[str, str, int, int, int]], typing.List[typing.Tuple[typing.Optional[str], str, int, int, int]]]: """Obtain a summary of the results, filtered by event type, module or entity. Args: @@ -770,6 +773,8 @@ def scanResultSummary(self, instanceId: str, by: str = "type") -> list: if by not in ["type", "module", "entity"]: raise ValueError(f"Invalid filter by value: {by}") from None + qry = "" + if by == "type": qry = "SELECT r.type, e.event_descr, MAX(ROUND(generated)) AS last_in, \ count(*) AS total, count(DISTINCT r.data) as utotal FROM \ @@ -799,7 +804,7 @@ def scanResultSummary(self, instanceId: str, by: str = "type") -> list: except sqlite3.Error as e: raise IOError("SQL error encountered when fetching result summary") from e - def scanCorrelationSummary(self, instanceId: str, by: str = "rule") -> list: + def scanCorrelationSummary(self, instanceId: str, by: str = "rule") -> typing.Union[typing.List[typing.Tuple[str, int]], typing.List[typing.Tuple[str, str, str, str, int]]]: """Obtain a summary of the correlations, filtered by rule or risk Args: @@ -824,6 +829,8 @@ def scanCorrelationSummary(self, instanceId: str, by: str = "rule") -> list: if by not in ["rule", "risk"]: raise ValueError(f"Invalid filter by value: {by}") from None + qry = "" + if by == "risk": qry = "SELECT rule_risk, count(*) AS total FROM \ tbl_scan_correlation_results \ @@ -844,7 +851,7 @@ def scanCorrelationSummary(self, instanceId: str, by: str = "rule") -> list: except sqlite3.Error as e: raise IOError("SQL error encountered when fetching correlation summary") from e - def scanCorrelationList(self, instanceId: str) -> list: + def scanCorrelationList(self, instanceId: str) -> typing.List[typing.Tuple[str, str, str, str, str, str, str, int]]: """Obtain a list of the correlations from a scan Args: @@ -879,21 +886,21 @@ def scanCorrelationList(self, instanceId: str) -> list: def scanResultEvent( self, instanceId: str, - eventType: str = 'ALL', - srcModule: str = None, - data: list = None, - sourceId: list = None, - correlationId: str = None, + eventType: typing.Union[str, typing.List[str]] = 'ALL', + srcModule: typing.Optional[typing.Union[str, typing.List[str]]] = None, + data: typing.Optional[typing.Union[str, typing.List[str]]] = None, + sourceId: typing.Optional[typing.Union[str, typing.List[str]]] = None, + correlationId: typing.Optional[str] = None, filterFp: bool = False - ) -> list: + ) -> typing.List[typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int]]: """Obtain the data for a scan and event type. Args: instanceId (str): scan instance ID - eventType (str): filter by event type - srcModule (str): filter by the generating module - data (list): filter by the data - sourceId (list): filter by the ID of the source event + eventType (str | list): filter by event type + srcModule (str | list): filter by the generating module + data (str | list): filter by the data + sourceId (str | list): filter by the ID of the source event correlationId (str): filter by the ID of a correlation result filterFp (bool): filter false positives @@ -974,7 +981,7 @@ def scanResultEvent( except sqlite3.Error as e: raise IOError("SQL error encountered when fetching result events") from e - def scanResultEventUnique(self, instanceId: str, eventType: str = 'ALL', filterFp: bool = False) -> list: + def scanResultEventUnique(self, instanceId: str, eventType: str = 'ALL', filterFp: bool = False) -> typing.List[typing.Tuple[typing.Optional[str], str, int]]: """Obtain a unique list of elements. Args: @@ -1016,7 +1023,7 @@ def scanResultEventUnique(self, instanceId: str, eventType: str = 'ALL', filterF except sqlite3.Error as e: raise IOError("SQL error encountered when fetching unique result events") from e - def scanLogs(self, instanceId: str, limit: int = None, fromRowId: int = 0, reverse: bool = False) -> list: + def scanLogs(self, instanceId: str, limit: typing.Optional[int] = None, fromRowId: int = 0, reverse: bool = False) -> typing.List[typing.Tuple[int, typing.Optional[str], str, typing.Optional[str], int]]: """Get scan logs. Args: @@ -1062,7 +1069,7 @@ def scanLogs(self, instanceId: str, limit: int = None, fromRowId: int = 0, rever except sqlite3.Error as e: raise IOError("SQL error encountered when fetching scan logs") from e - def scanErrors(self, instanceId: str, limit: int = 0) -> list: + def scanErrors(self, instanceId: str, limit: int = 0) -> typing.List[typing.Tuple[int, typing.Optional[str], typing.Optional[str]]]: """Get scan errors. Args: @@ -1134,7 +1141,7 @@ def scanInstanceDelete(self, instanceId: str) -> bool: return True - def scanResultsUpdateFP(self, instanceId: str, resultHashes: list, fpFlag: int) -> bool: + def scanResultsUpdateFP(self, instanceId: str, resultHashes: typing.List[str], fpFlag: int) -> bool: """Set the false positive flag for a result. Args: @@ -1173,7 +1180,7 @@ def scanResultsUpdateFP(self, instanceId: str, resultHashes: list, fpFlag: int) return True - def configSet(self, optMap: dict = {}) -> bool: + def configSet(self, optMap: typing.Dict[str, str] = {}) -> bool: """Store the default configuration in the database. Args: @@ -1217,7 +1224,7 @@ def configSet(self, optMap: dict = {}) -> bool: return True - def configGet(self) -> dict: + def configGet(self) -> typing.Dict[str, str]: """Retreive the config from the database Returns: @@ -1229,7 +1236,7 @@ def configGet(self) -> dict: qry = "SELECT scope, opt, val FROM tbl_config" - retval = dict() + retval: typing.Dict[str, str] = dict() with self.dbhLock: try: @@ -1261,11 +1268,11 @@ def configClear(self) -> None: except sqlite3.Error as e: raise IOError("Unable to clear configuration from the database") from e - def scanConfigSet(self, scan_id, optMap=dict()) -> None: + def scanConfigSet(self, scan_id: str, optMap: typing.Dict[str, str] = dict()) -> None: """Store a configuration value for a scan. Args: - scan_id (int): scan instance ID + scan_id (str): scan instance ID optMap (dict): config options Raises: @@ -1302,7 +1309,7 @@ def scanConfigSet(self, scan_id, optMap=dict()) -> None: except sqlite3.Error as e: raise IOError("SQL error encountered when storing config, aborting") from e - def scanConfigGet(self, instanceId: str) -> dict: + def scanConfigGet(self, instanceId: str) -> typing.Dict[str, str]: """Retrieve configuration data for a scan component. Args: @@ -1323,7 +1330,7 @@ def scanConfigGet(self, instanceId: str) -> dict: WHERE scan_instance_id = ? ORDER BY component, opt" qvars = [instanceId] - retval = dict() + retval: typing.Dict[str, str] = dict() with self.dbhLock: try: @@ -1337,7 +1344,7 @@ def scanConfigGet(self, instanceId: str) -> dict: except sqlite3.Error as e: raise IOError("SQL error encountered when fetching configuration") from e - def scanEventStore(self, instanceId: str, sfEvent, truncateSize: int = 0) -> None: + def scanEventStore(self, instanceId: str, sfEvent: SpiderFootEvent, truncateSize: int = 0) -> None: """Store an event in the database. Args: @@ -1350,8 +1357,6 @@ def scanEventStore(self, instanceId: str, sfEvent, truncateSize: int = 0) -> Non ValueError: arg value was invalid IOError: database I/O failed """ - from spiderfoot import SpiderFootEvent - if not isinstance(instanceId, str): raise TypeError(f"instanceId is {type(instanceId)}; expected str()") from None @@ -1435,7 +1440,7 @@ def scanEventStore(self, instanceId: str, sfEvent, truncateSize: int = 0) -> Non except sqlite3.Error as e: raise IOError(f"SQL error encountered when storing event data ({self.dbh})") from e - def scanInstanceList(self) -> list: + def scanInstanceList(self) -> typing.List[typing.Tuple[str, str, str, typing.Optional[int], typing.Optional[int], typing.Optional[int], str, int]]: """List all previously run scans. Returns: @@ -1466,7 +1471,7 @@ def scanInstanceList(self) -> list: except sqlite3.Error as e: raise IOError("SQL error encountered when fetching scan list") from e - def scanResultHistory(self, instanceId: str) -> list: + def scanResultHistory(self, instanceId: str) -> typing.List[typing.Tuple[typing.Optional[str], str, int]]: """History of data from the scan. Args: @@ -1495,7 +1500,7 @@ def scanResultHistory(self, instanceId: str) -> list: except sqlite3.Error as e: raise IOError(f"SQL error encountered when fetching history for scan {instanceId}") from e - def scanElementSourcesDirect(self, instanceId: str, elementIdList: list) -> list: + def scanElementSourcesDirect(self, instanceId: str, elementIdList: typing.List[str]) -> typing.List[typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int, str, str, str]]: """Get the source IDs, types and data for a set of IDs. Args: @@ -1516,7 +1521,7 @@ def scanElementSourcesDirect(self, instanceId: str, elementIdList: list) -> list if not isinstance(elementIdList, list): raise TypeError(f"elementIdList is {type(elementIdList)}; expected list()") from None - hashIds = [] + hashIds: typing.List[str] = [] for hashId in elementIdList: if not hashId: continue @@ -1546,7 +1551,7 @@ def scanElementSourcesDirect(self, instanceId: str, elementIdList: list) -> list except sqlite3.Error as e: raise IOError("SQL error encountered when getting source element IDs") from e - def scanElementChildrenDirect(self, instanceId: str, elementIdList: list) -> list: + def scanElementChildrenDirect(self, instanceId: str, elementIdList: typing.List[str]) -> typing.List[typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int]]: """Get the child IDs, types and data for a set of IDs. Args: @@ -1567,7 +1572,7 @@ def scanElementChildrenDirect(self, instanceId: str, elementIdList: list) -> lis if not isinstance(elementIdList, list): raise TypeError(f"elementIdList is {type(elementIdList)}; expected list()") - hashIds = [] + hashIds: typing.List[str] = [] for hashId in elementIdList: if not hashId: continue @@ -1595,7 +1600,7 @@ def scanElementChildrenDirect(self, instanceId: str, elementIdList: list) -> lis except sqlite3.Error as e: raise IOError("SQL error encountered when getting child element IDs") from e - def scanElementSourcesAll(self, instanceId: str, childData: list) -> list: + def scanElementSourcesAll(self, instanceId: str, childData: typing.List[typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int]]) -> typing.List[typing.Union[typing.Dict[str, typing.Union[typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int], typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int, str, str, str]]], typing.Dict[str, typing.List[str]]]]: """Get the full set of upstream IDs which are parents to the supplied set of IDs. Args: @@ -1621,9 +1626,9 @@ def scanElementSourcesAll(self, instanceId: str, childData: list) -> list: # Get the first round of source IDs for the leafs keepGoing = True - nextIds = list() - datamap = dict() - pc = dict() + nextIds: typing.List[str] = list() + datamap: typing.Dict[str, typing.Union[typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int], typing.Tuple[int, typing.Optional[str], typing.Optional[str], str, str, int, int, int, str, str, str, str, str, int, int, str, str, str]]] = dict() + pc: typing.Dict[str, typing.List[str]] = dict() for row in childData: # these must be unique values! @@ -1663,10 +1668,13 @@ def scanElementSourcesAll(self, instanceId: str, childData: list) -> list: if parentId != "ROOT": keepGoing = True + assert parentId + assert row + datamap[parentId] = row return [datamap, pc] - def scanElementChildrenAll(self, instanceId: str, parentIds: list) -> list: + def scanElementChildrenAll(self, instanceId: str, parentIds: typing.List[str]) -> typing.List[str]: """Get the full set of downstream IDs which are children of the supplied set of IDs. Args: @@ -1689,9 +1697,9 @@ def scanElementChildrenAll(self, instanceId: str, parentIds: list) -> list: if not isinstance(parentIds, list): raise TypeError(f"parentIds is {type(parentIds)}; expected list()") - datamap = list() + datamap: typing.List[str] = list() keepGoing = True - nextIds = list() + nextIds: typing.List[str] = list() nextSet = self.scanElementChildrenDirect(instanceId, parentIds) for row in nextSet: @@ -1723,7 +1731,7 @@ def correlationResultCreate( ruleRisk: str, ruleYaml: str, correlationTitle: str, - eventHashes: list + eventHashes: typing.List[str], ) -> str: """Create a correlation result in the database.