diff --git a/Makefile b/Makefile index 52d0d3569..6094db00a 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,13 @@ update-dockers: ## Update docker images else \ echo "Skipping nebula-controller docker build."; \ fi + @echo "🐳 Building nebula-database docker image. Do you want to continue (overrides existing image)? (y/n)" + @read ans; if [ "$${ans:-N}" = y ]; then \ + docker build -t nebula-database -f nebula/database/adapters/postgress/docker/Dockerfile .; \ + docker build -t nebula-pgweb -f nebula/database/pgweb/Dockerfile .; \ + else \ + echo "Skipping nebula-database docker build."; \ + fi @echo "" @echo "🐳 Building nebula-frontend docker image. Do you want to continue (overrides existing image)? (y/n)" @read ans; if [ "$${ans:-N}" = y ]; then \ diff --git a/app/deployer.py b/app/deployer.py index 968ef62da..889421330 100644 --- a/app/deployer.py +++ b/app/deployer.py @@ -1,7 +1,9 @@ import json import logging import os +import secrets import signal +import string import subprocess import sys import threading @@ -15,9 +17,95 @@ from watchdog.observers import Observer from nebula.addons.env import check_environment -from nebula.controller.controller import TermEscapeCodeFormatter +from nebula.controller.hub import TermEscapeCodeFormatter from nebula.controller.scenarios import ScenarioManagement -from nebula.utils import DockerUtils, SocketUtils +from nebula.utils import DockerUtils, FileUtils, SocketUtils + +class CredentialManager: + """ + CredentialManager handles the generation, storage, and validation of environment-based credentials. + + This class is designed to manage credentials required for different system components like the frontend, + Grafana, and the database. It ensures that secure values are generated and persisted in a `.env` file if + they are not already defined in the environment. + + Attributes: + env_path (Path): Absolute path to the environment file where credentials will be stored. + + Typical usage example: + manager = CredentialManager() + manager.check_all_credentials() + """ + + def __init__(self, env_dir="app", env_filename=".env"): + """ + Initializes the CredentialManager and loads existing environment variables from file. + + Args: + env_dir (str): Directory where the .env file is located. Defaults to 'app'. + env_filename (str): Name of the environment file. Defaults to '.env'. + + Behavior: + - Sets up the absolute path to the .env file. + - Loads any existing environment variables from the file using `load_dotenv`. + """ + self.env_path = Path.cwd() / env_dir / env_filename + if os.path.exists(self.env_path): + logging.info(f"Loading environment variables from {self.env_path}") + load_dotenv(self.env_path, override=True) + + def generate_secure_password(self, length=20): + """ + Generates a cryptographically secure and readable password including symbols. + + Args: + length (int): Length of the password. Defaults to 20. + + Returns: + str: A randomly generated secure password, excluding confusing or problematic characters. + """ + alphabet = string.ascii_letters + string.digits + string.punctuation + for char in ['"', "'", "\\", "`", "|", "(", ")", "{", "}", "[", "]", "#"]: + alphabet = alphabet.replace(char, "") + return ''.join(secrets.choice(alphabet) for _ in range(length)) + + def check_credential(self, key, is_password=True): + """ + Checks if a given credential key is present in the environment. If not, generates and saves it. + + Args: + key (str): The environment variable key to check or create. 
+ is_password (bool): If True, generates a secure password. If False, generates a hex token. Defaults to True. + + Behavior: + - If the key is missing, a value is generated and stored both in the environment and the `.env` file. + - If the key exists, no action is taken. + """ + if key not in os.environ: + logging.info(f"Generating value for {key}") + value = self.generate_secure_password(12) if is_password else secrets.token_hex(24) + os.environ[key] = value + logging.info(f"Saving {key} to {self.env_path}") + with self.env_path.open("a") as f: + f.write(f"{key}={value}\n") + else: + logging.info(f"{key} already set") + + def check_all_credentials(self): + """ + Checks and sets all required credentials for the application. + + This method should be called at startup to ensure all necessary keys are initialized. + + Includes: + - Frontend secret key + - Grafana admin password + - (Optional) Database password + """ + self.check_credential("SECRET_KEY", is_password=False) + self.check_credential("GF_SECURITY_ADMIN_PASSWORD") + self.check_credential("POSTGRES_PASSWORD") + self.check_credential("NEBULA_ADMIN_PASSWORD") class NebulaEventHandler(PatternMatchingEventHandler): @@ -289,17 +377,17 @@ def run_script(self, script): def kill_script_processes(self, pids_file): """ Forcefully terminates processes listed in a given PID file, including their child processes. - + Args: pids_file (str): Path to the file containing PIDs, one per line. - + Behavior: - Reads the PIDs from the file. - For each PID, checks if the process exists. - If it exists, kills all child processes recursively before killing the main process. - Handles and logs exceptions such as missing processes or invalid PID entries. - Logs warnings and errors appropriately. - + Typical use case: Used to clean up running processes related to a scenario or script that has been deleted or stopped. """ @@ -344,7 +432,7 @@ def run_observer(): """ Starts a watchdog observer to monitor the configuration directory for changes. - This function is typically used to execute additional scripts or trigger events + This function is typically used to execute additional scripts or trigger events during the execution of a federated learning session by monitoring file system changes. Main functionalities: @@ -357,7 +445,7 @@ def run_observer(): - Trigger specific actions during a federation lifecycle. Note: - The observer runs in a blocking mode and will keep the process alive + The observer runs in a blocking mode and will keep the process alive until manually stopped or interrupted. """ # Watchdog for running additional scripts in the host machine (i.e. during the execution of a federation) @@ -373,8 +461,8 @@ class Deployer: """ Handles the configuration and initialization of deployment parameters for the NEBULA system. - This class reads and stores various deployment-related settings such as port assignments, - environment paths, logging configuration, and system mode (production, development, or simulation). + This class reads and stores various deployment-related settings such as port assignments, + environment paths, logging configuration, and system mode (production or development). Main functionalities: - Parses and validates input arguments for deployment. @@ -384,7 +472,7 @@ class Deployer: Typical use cases: - Used to deploy the NEBULA system components with the correct configuration. - - Enables deployment in different environments (e.g., local simulation, production, development). 
+ - Enables deployment in different environments (e.g., production, development). Attributes: - controller_port (int): Port for the main controller service. @@ -395,9 +483,7 @@ class Deployer: - statistics_port (int): Port for the statistics service. - production (bool): Flag indicating if the system is in production mode. - dev (bool): Flag indicating if the system is in development mode. - - advanced_analytics (bool): Enables advanced analytics modules. - databases_dir (str): Path to the database directory. - - simulation (str): Simulation scenario path. - config_dir (str): Path to the configuration directory. - log_dir (str): Path to the logs directory. - env_path (str): Path to the Python environment. @@ -409,8 +495,150 @@ class Deployer: Note: This class does not launch any services directly; it only prepares and stores configuration. """ + + DEPLOYER_PID_FILE = os.path.join(os.path.dirname(__file__), "deployer.pid") + METADATA_FILE = os.path.join(os.path.dirname(__file__), "deployer.metadata") + + @staticmethod + def _read_metadata(): + try: + with open(Deployer.METADATA_FILE, "r") as f: + data = json.load(f) + # Backward compatibility: if it's a list, treat as containers only + if isinstance(data, list): + return {"containers": data, "networks": []} + return data + except (FileNotFoundError, json.JSONDecodeError): + return {"containers": [], "networks": []} + + @staticmethod + def _write_metadata(metadata): + with open(Deployer.METADATA_FILE, "w") as f: + json.dump(metadata, f, indent=2) + + @staticmethod + def _add_container_to_metadata(container_name): + metadata = Deployer._read_metadata() + if container_name not in metadata["containers"]: + metadata["containers"].append(container_name) + Deployer._write_metadata(metadata) + + @staticmethod + def _add_network_to_metadata(network_name): + metadata = Deployer._read_metadata() + if network_name not in metadata["networks"]: + metadata["networks"].append(network_name) + Deployer._write_metadata(metadata) + + @staticmethod + def _remove_all_containers_from_metadata(): + metadata = Deployer._read_metadata() + containers = metadata["containers"] + if containers: + try: + import docker + + client = docker.from_env() + for name in containers: + try: + container = client.containers.get(name) + container.remove(force=True) + logging.info(f"Container {name} removed via metadata.") + except Exception as e: + logging.warning(f"Could not remove container {name}: {e}") + except Exception as e: + logging.warning(f"Docker error during metadata removal: {e}") + metadata["containers"] = [] + Deployer._write_metadata(metadata) + + @staticmethod + def _remove_all_networks_from_metadata(): + metadata = Deployer._read_metadata() + networks = metadata["networks"] + if networks: + try: + import docker + + client = docker.from_env() + for name in networks: + try: + network = client.networks.get(name) + network.remove() + logging.info(f"Network {name} removed via metadata.") + except Exception as e: + logging.warning(f"Could not remove network {name}: {e}") + except Exception as e: + logging.warning(f"Docker error during network metadata removal: {e}") + metadata["networks"] = [] + Deployer._write_metadata(metadata) + def __init__(self, args): + """ + Initializes the Deployer with robust handling of environment and prefix logic using tags only. + - Only sets NEBULA_ENV_TAG, NEBULA_PREFIX_TAG, and NEBULA_USER_TAG in the .env file. + - All logic and naming use tag helpers and tag variables. + - Defaults for prefix and production are consistent with main.py. 
+ """ + # Prevent running NEBULA twice by checking metadata file + if os.path.exists(self.METADATA_FILE): + try: + with open(self.METADATA_FILE) as f: + data = json.load(f) + if (isinstance(data, dict) and (data.get("containers") or data.get("networks"))) or ( + isinstance(data, list) and data + ): + warning_msg = ( + "\n\033[91mERROR: NEBULA appears to be already running or was not cleanly shut down. " + "Please stop the existing instance or remove the metadata file before starting a new one.\033[0m\n" + "You can use 'docker ps -a --filter name={deployment_prefix}' to see the containers." + ) + logging.exception(warning_msg) + sys.exit(1) + except Exception: + warning_msg = ( + "\n\033[91mERROR: NEBULA metadata file is corrupt or unreadable. " + "Please remove or fix the file before starting a new instance.\033[0m\n" + ) + logging.exception(warning_msg) + sys.exit(1) + + self.configure_logger() + self.credentialmanager = CredentialManager() + self.credentialmanager.check_all_credentials() + + # --- Tag logic: CLI args > environment > fallback --- + arg_production = getattr(args, "production", False) + arg_prefix = getattr(args, "prefix", "dev") + arg_user = os.environ.get("USER", "unknown") + + env_tag = os.environ.get("NEBULA_ENV_TAG") + prefix_tag = os.environ.get("NEBULA_PREFIX_TAG") + user_tag = os.environ.get("NEBULA_USER_TAG", arg_user) + + self.env_tag = ("prod" if arg_production else "dev") if env_tag is None else env_tag + self.prefix_tag = arg_prefix if arg_prefix else (prefix_tag if prefix_tag else "dev") + self.user_tag = user_tag + + FileUtils.update_env_file(getattr(args, "env", ".env"), "NEBULA_ENV_TAG", self.env_tag) + FileUtils.update_env_file(getattr(args, "env", ".env"), "NEBULA_PREFIX_TAG", self.prefix_tag) + FileUtils.update_env_file(getattr(args, "env", ".env"), "NEBULA_USER_TAG", self.user_tag) + + self.production = self.env_tag == "prod" + self.prefix = self.prefix_tag + + deployment_prefix = f"{self.env_tag}_{self.prefix_tag}_{self.user_tag}_" + if DockerUtils.check_docker_by_prefix(deployment_prefix): + warning_msg = ( + f"\n\033[91mERROR: Found existing Docker containers with prefix '{deployment_prefix}'. " + f"NEBULA cannot be deployed with the same prefix. " + f"Please stop/remove existing containers before starting a new deployment.\033[0m\n" + f"You can use 'docker ps -a --filter name={deployment_prefix}' to see the containers." 
+ ) + logging.exception(warning_msg) + sys.exit(1) + self.controller_port = int(args.controllerport) if hasattr(args, "controllerport") else 5050 + self.federation_controller_port = int(args.federationcontrollerport) if hasattr(args, "federationcontrollerport") else 5052 self.waf_port = int(args.wafport) if hasattr(args, "wafport") else 6000 self.frontend_port = int(args.webport) if hasattr(args, "webport") else 6060 self.grafana_port = int(args.grafanaport) if hasattr(args, "grafanaport") else 6040 @@ -420,7 +648,6 @@ def __init__(self, args): self.dev = args.developement if hasattr(args, "developement") else False self.advanced_analytics = args.advanced_analytics if hasattr(args, "advanced_analytics") else False self.databases_dir = args.databases if hasattr(args, "databases") else "/nebula/app/databases" - self.simulation = args.simulation self.config_dir = args.config self.log_dir = args.logs self.env_path = args.env @@ -430,15 +657,60 @@ def __init__(self, args): else os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) self.host_platform = "windows" if sys.platform == "win32" else "unix" - self.controller_host = f"{os.environ['USER']}_nebula-controller" self.gpu_available = False - self.configure_logger() + + self.controller_host = self.get_container_name("nebula-controller") + self.controller_port = int(args.controllerport) if hasattr(args, "controllerport") else 5050 + self.waf_port = int(args.wafport) if hasattr(args, "wafport") else 6000 + self.frontend_port = int(args.webport) if hasattr(args, "webport") else 6060 + self.grafana_port = int(args.grafanaport) if hasattr(args, "grafanaport") else 6040 + self.loki_port = int(args.lokiport) if hasattr(args, "lokiport") else 6010 + self.statistics_port = int(args.statsport) if hasattr(args, "statsport") else 8080 + + + def get_container_name(self, role_tag: str) -> str: + """ + Generate a standardized container name using tags. + Args: + role_tag (str): The component role (e.g., 'nebula-controller'). + Returns: + str: The composed container name. + """ + return f"{self.env_tag}_{self.prefix_tag}_{self.user_tag}_{role_tag}" + + def get_network_name(self, suffix: str) -> str: + """ + Generate a standardized network name using tags. + Args: + suffix (str): Suffix for the network (default: 'net-base'). + Returns: + str: The composed network name. + """ + return f"{self.env_tag}_{self.prefix_tag}_{self.user_tag}_{suffix}" + + @property + def deployment_prefix(self): + """ + Returns the deployment prefix for the current deployment. + + This property is used to prefix the names of the containers and networks + in the deployment. + + Returns: + str: The deployment prefix, either "production" or "dev". + + Typical use cases: + - Prefixing container and network names in the deployment. + - Ensuring consistent naming conventions across different environments. + + """ + return self.prefix def configure_logger(self): """ Configures the logging system for the deployment controller. - This method sets up both console and file logging with a consistent format and appropriate log levels. + This method sets up both console and file logging with a consistent format and appropriate log levels. It also ensures that Uvicorn loggers are properly configured to avoid duplicate log outputs. Main functionalities: @@ -452,7 +724,7 @@ def configure_logger(self): - Ensures clean and consistent logging output during deployment. 
Note: - This method does not set up file logging directly, but prepares the base configuration + This method does not set up file logging directly, but prepares the base configuration and Uvicorn logger behavior for further logging use. """ log_console_format = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s" @@ -475,7 +747,7 @@ def ensure_directory_access(self, directory_path: str) -> str: """ Ensures that the specified directory exists and is writable. - This method attempts to create the directory if it does not exist and verifies + This method attempts to create the directory if it does not exist and verifies write access by writing and deleting a temporary metadata file. Args: @@ -514,15 +786,17 @@ def ensure_directory_access(self, directory_path: str) -> str: except Exception as e: logging.exception(f"Failed to create/access directory {directory_path}: {str(e)}") - logging.exception("Please check directory permissions or choose a different location using --database option") + logging.exception( + "Please check directory permissions or choose a different location using --database option" + ) raise SystemExit(1) from e def start(self): """ Starts the NEBULA deployment process and all associated services. - This method initializes the NEBULA platform by setting up the environment, - checking port availability, starting key services (controller, frontend, WAF), + This method initializes the NEBULA platform by setting up the environment, + checking port availability, starting key services (controller, frontend, WAF), and launching a filesystem observer to react to configuration changes. Main functionalities: @@ -535,14 +809,14 @@ def start(self): - Handles system signals for clean shutdown. Typical use cases: - - Used to launch NEBULA in production, development, or simulation environments. + - Used to launch NEBULA in production or development environments. - Central entry point for managing NEBULA components during deployment. Note: - The method blocks indefinitely until manually interrupted, + The method blocks indefinitely until manually interrupted, and ensures graceful shutdown upon receiving SIGINT or SIGTERM. 
""" - banner = """ + banner = f""" ███╗ ██╗███████╗██████╗ ██╗ ██╗██╗ █████╗ ████╗ ██║██╔════╝██╔══██╗██║ ██║██║ ██╔══██╗ ██╔██╗ ██║█████╗ ██████╔╝██║ ██║██║ ███████║ @@ -558,6 +832,8 @@ def start(self): • Fernando Torres Vega https://nebula-dfl.com / https://nebula-dfl.eu + + [{"Production" if self.production else "Development"} mode] [{self.deployment_prefix} prefix] """ print("\x1b[0;36m" + banner + "\x1b[0m") @@ -578,6 +854,9 @@ def start(self): if not SocketUtils.is_port_open(self.controller_port): self.controller_port = SocketUtils.find_free_port(start_port=self.controller_port) + if not SocketUtils.is_port_open(self.federation_controller_port): + self.federation_controller_port = SocketUtils.find_free_port(start_port=self.federation_controller_port) + if not SocketUtils.is_port_open(self.frontend_port): self.frontend_port = SocketUtils.find_free_port(start_port=self.frontend_port) @@ -586,10 +865,12 @@ def start(self): self.run_controller() logging.info("NEBULA Controller is running") - logging.info(f"NEBULA Databases created in {self.databases_dir}") + self.run_database() + logging.info(f"NEBULA Databases docker is running") self.run_frontend() logging.info(f"NEBULA Frontend is running at http://localhost:{self.frontend_port}") - if self.production: + if self.production and self.prefix == "production": + logging.info("Deploying NEBULA WAF in production mode") self.run_waf() logging.info("NEBULA WAF is running") @@ -616,8 +897,8 @@ def signal_handler(self, sig, frame): """ Handles system termination signals to ensure a clean shutdown. - This method is triggered when the application receives SIGTERM or SIGINT signals - (e.g., via Ctrl+C or `kill`). It logs the event, performs cleanup actions, and + This method is triggered when the application receives SIGTERM or SIGINT signals + (e.g., via Ctrl+C or `kill`). It logs the event, performs cleanup actions, and terminates the process gracefully. Args: @@ -639,19 +920,19 @@ def signal_handler(self, sig, frame): def run_frontend(self): """ - Runs the Nebula controller within a Docker container, ensuring the required Docker environment is available. + Runs the NEBULA controller within a Docker container, ensuring the required Docker environment is available. This method: - Checks if Docker is running by verifying the Docker socket presence (platform-dependent). - - Creates a dedicated Docker network for the Nebula system. + - Creates a dedicated Docker network for the NEBULA system. - Configures environment variables, volume mounts, ports, and network settings for the container. - - Creates and starts the Nebula controller Docker container with the specified configuration. + - Creates and starts the NEBULA controller Docker container with the specified configuration. Raises: Exception: If Docker is not running or Docker Compose is not installed. Typical use cases: - - Launching the Nebula controller as part of the federated learning infrastructure. + - Launching the NEBULA controller as part of the federated learning infrastructure. - Ensuring proper Docker networking and environment setup for container execution. Note: @@ -668,17 +949,20 @@ def run_frontend(self): "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed." 
) - network_name = f"{os.environ['USER']}_nebula-net-base" + network_name = self.get_network_name("net-base") # Create the Docker network base = DockerUtils.create_docker_network(network_name) + Deployer._add_network_to_metadata(network_name) client = docker.from_env() environment = { - "NEBULA_CONTROLLER_NAME": os.environ["USER"], + "SECRET_KEY": os.environ.get("SECRET_KEY"), "NEBULA_PRODUCTION": self.production, - "NEBULA_ADVANCED_ANALYTICS": self.advanced_analytics, + "NEBULA_ENV_TAG": self.env_tag, + "NEBULA_PREFIX_TAG": self.prefix_tag, + "NEBULA_USER_TAG": self.user_tag, "NEBULA_FRONTEND_LOG": "/nebula/app/logs/frontend.log", "NEBULA_LOGS_DIR": "/nebula/app/logs/", "NEBULA_CONFIG_DIR": "/nebula/app/config/", @@ -709,9 +993,19 @@ def run_frontend(self): f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.100") }) + frontend_container_name = self.get_container_name("nebula-frontend") + + try: + existing = client.containers.get(frontend_container_name) + logging.warning( + f"Container {frontend_container_name} already exists. Deployment may fail or cause conflicts." + ) + except docker.errors.NotFound: + pass # No conflict, safe to proceed + container_id = client.api.create_container( image="nebula-frontend", - name=f"{os.environ['USER']}_nebula-frontend", + name=frontend_container_name, detach=True, environment=environment, volumes=volumes, @@ -721,20 +1015,87 @@ def run_frontend(self): ) client.api.start(container_id) + # Add to metadata + Deployer._add_container_to_metadata(frontend_container_name) - @staticmethod - def stop_frontend(): + def run_database(self): """ - Stops and removes all NEBULA frontend Docker containers associated with the current user. + Runs the Nebula database within a Docker container, ensuring the required Docker environment is available. + """ + if sys.platform == "win32": + if not os.path.exists("//./pipe/docker_Engine"): + raise Exception( + "Docker is not running, please check if Docker is running and Docker Compose is installed." + ) + else: + if not os.path.exists("/var/run/docker.sock"): + raise Exception( + "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed." + ) - Responsibilities: - - Detects running Docker containers with names starting with '_nebula-frontend'. - - Gracefully stops and removes these frontend containers. + network_name = self.get_network_name("net-base") + base = DockerUtils.create_docker_network(network_name) + Deployer._add_network_to_metadata(network_name) - Typical use cases: - - Cleaning up frontend containers during shutdown or redeployment processes. 
- """ - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-frontend") + client = docker.from_env() + + # --- PostgreSQL --- + pg_container_name = self.get_container_name("nebula-database") + pg_environment = { + "POSTGRES_USER": "nebula", + "POSTGRES_PASSWORD": os.environ.get("POSTGRES_PASSWORD"), + "POSTGRES_DB": "nebula", + "NEBULA_DATABASE_LOG": "/nebula/app/logs/database.log", + "DB_HOST": "localhost", + "DB_PORT": 5432, + "DB_USER": "nebula", + "DB_PASSWORD": os.environ.get("POSTGRES_PASSWORD"), + "NEBULA_ADMIN_PASSWORD": os.environ.get("NEBULA_ADMIN_PASSWORD") + } + host_sql_path = os.path.join(self.root_path, "nebula/database/adapters/postgress/docker/init-configs.sql") + db_data_path = os.path.join(self.databases_dir, "postgres-data") + os.makedirs(db_data_path, exist_ok=True) + + pg_host_config = client.api.create_host_config( + binds=[ + f"{self.root_path}:/nebula", + f"{host_sql_path}:/docker-entrypoint-initdb.d/init-configs.sql", + f"{db_data_path}:/var/lib/postgresql/data", + ], + port_bindings={5432: 5432, 5051: 5051}, + ) + pg_networking_config = client.api.create_networking_config( + {f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.125")} + ) + + pg_container = client.api.create_container( + image="nebula-database", + name=pg_container_name, + detach=True, + environment=pg_environment, + host_config=pg_host_config, + networking_config=pg_networking_config, + ports=[5432, 5051], + ) + client.api.start(pg_container) + Deployer._add_container_to_metadata(pg_container_name) + + # --- PGWeb --- + pgweb_container_name = self.get_container_name("nebula-pgweb") + pgweb_host_config = client.api.create_host_config(port_bindings={8081: 8085}) + pgweb_networking_config = client.api.create_networking_config( + {f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.135")} + ) + + pgweb_container = client.api.create_container( + image="nebula-pgweb", + name=pgweb_container_name, + detach=True, + host_config=pgweb_host_config, + networking_config=pgweb_networking_config, + ) + client.api.start(pgweb_container) + Deployer._add_container_to_metadata(pgweb_container_name) def run_controller(self): if sys.platform == "win32": @@ -748,8 +1109,8 @@ def run_controller(self): "/var/run/docker.sock not found, please check if Docker is running and Docker Compose is installed." 
) - network_name = f"{os.environ['USER']}_nebula-net-base" - + network_name = self.get_network_name("net-base") + try: subprocess.check_call(["nvidia-smi"]) self.gpu_available = True @@ -758,37 +1119,45 @@ def run_controller(self): # Create the Docker network base = DockerUtils.create_docker_network(network_name) + Deployer._add_network_to_metadata(network_name) client = docker.from_env() environment = { "USER": os.environ["USER"], - "NEBULA_PRODUCTION": self.production, + "NEBULA_ENV_TAG": self.env_tag, + "NEBULA_PREFIX_TAG": self.prefix_tag, + "NEBULA_USER_TAG": self.user_tag, "NEBULA_ROOT_HOST": self.root_path, - "NEBULA_ADVANCED_ANALYTICS": self.advanced_analytics, "NEBULA_DATABASES_DIR": "/nebula/app/databases", "NEBULA_CONTROLLER_LOG": "/nebula/app/logs/controller.log", + "NEBULA_FEDERATION_CONTROLLER_LOG": "/nebula/app/logs/federation.log", "NEBULA_CONFIG_DIR": "/nebula/app/config/", "NEBULA_LOGS_DIR": "/nebula/app/logs/", "NEBULA_CERTS_DIR": "/nebula/app/certs/", "NEBULA_HOST_PLATFORM": self.host_platform, "NEBULA_CONTROLLER_PORT": self.controller_port, + "NEBULA_FEDERATION_CONTROLLER_PORT" : self.federation_controller_port, "NEBULA_CONTROLLER_HOST": self.controller_host, "NEBULA_FRONTEND_PORT": self.frontend_port, + "NEBULA_DATABASE_API_URL": f"http://{self.get_container_name('nebula-database')}:5051" } volumes = ["/nebula", "/var/run/docker.sock"] - ports = [self.controller_port] + ports = [self.controller_port, self.federation_controller_port] host_config = client.api.create_host_config( binds=[ f"{self.root_path}:/nebula", "/var/run/docker.sock:/var/run/docker.sock", - f"{self.databases_dir}:/nebula/app/databases" + f"{self.databases_dir}:/nebula/app/databases", ], extra_hosts={"host.docker.internal": "host-gateway"}, - port_bindings={self.controller_port: self.controller_port}, + port_bindings={ + self.controller_port: self.controller_port, + self.federation_controller_port: self.federation_controller_port + }, device_requests=[{ "Driver": "nvidia", "Count": -1, @@ -800,9 +1169,19 @@ def run_controller(self): f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.150") }) + controller_container_name = self.get_container_name("nebula-controller") + + try: + existing = client.containers.get(controller_container_name) + logging.warning( + f"Container {controller_container_name} already exists. Deployment may fail or cause conflicts." + ) + except docker.errors.NotFound: + pass # No conflict, safe to proceed + container_id = client.api.create_container( image="nebula-controller", - name=f"{os.environ['USER']}_nebula-controller", + name=controller_container_name, detach=True, environment=environment, volumes=volumes, @@ -812,21 +1191,8 @@ def run_controller(self): ) client.api.start(container_id) - - @staticmethod - def stop_controller(): - """ - Stops all running Docker containers with names starting with '_nebula-controller'. - - Responsibilities: - - Initiates shutdown of all participant nodes related to the scenario. - - Gracefully stops and removes controller containers to ensure clean shutdown. - - Typical use cases: - - Used when stopping or restarting the Nebula controller service. 
- """ - ScenarioManagement.stop_participants() - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-controller") + # Add to metadata + Deployer._add_container_to_metadata(controller_container_name) def run_waf(self): """ @@ -841,11 +1207,12 @@ def run_waf(self): - Assigns static IP addresses to all containers within the created Docker network for consistent communication. Typical use cases: - - Deploying an integrated WAF solution alongside monitoring and logging components in the Nebula system. + - Deploying an integrated WAF solution alongside monitoring and logging components in the NEBULA system. - Ensuring comprehensive security monitoring and log management through containerized services. """ - network_name = f"{os.environ['USER']}_nebula-net-base" + network_name = self.get_network_name("net-base") base = DockerUtils.create_docker_network(network_name) + Deployer._add_network_to_metadata(network_name) client = docker.from_env() @@ -863,9 +1230,17 @@ def run_waf(self): f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.200") }) + waf_container_name = self.get_container_name("nebula-waf") + + try: + existing = client.containers.get(waf_container_name) + logging.warning(f"Container {waf_container_name} already exists. Deployment may fail or cause conflicts.") + except docker.errors.NotFound: + pass # No conflict, safe to proceed + container_id_waf = client.api.create_container( image="nebula-waf", - name=f"{os.environ['USER']}_nebula-waf", + name=waf_container_name, detach=True, volumes=volumes_waf, host_config=host_config_waf, @@ -874,9 +1249,10 @@ def run_waf(self): ) client.api.start(container_id_waf) + Deployer._add_container_to_metadata(waf_container_name) environment = { - "GF_SECURITY_ADMIN_PASSWORD": "admin", + "GF_SECURITY_ADMIN_PASSWORD": os.environ.get("GF_SECURITY_ADMIN_PASSWORD"), "GF_USERS_ALLOW_SIGN_UP": "false", "GF_SERVER_HTTP_PORT": "3000", "GF_SERVER_PROTOCOL": "http", @@ -897,9 +1273,19 @@ def run_waf(self): f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.201") }) + waf_grafana_container_name = self.get_container_name("nebula-waf-grafana") + + try: + existing = client.containers.get(waf_grafana_container_name) + logging.warning( + f"Container {waf_grafana_container_name} already exists. Deployment may fail or cause conflicts." + ) + except docker.errors.NotFound: + pass # No conflict, safe to proceed + container_id = client.api.create_container( image="nebula-waf-grafana", - name=f"{os.environ['USER']}_nebula-waf-grafana", + name=waf_grafana_container_name, detach=True, environment=environment, host_config=host_config, @@ -908,6 +1294,7 @@ def run_waf(self): ) client.api.start(container_id) + Deployer._add_container_to_metadata(waf_grafana_container_name) command = ["-config.file=/mnt/config/loki-config.yml"] @@ -921,9 +1308,19 @@ def run_waf(self): f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.202") }) + waf_loki_container_name = self.get_container_name("nebula-waf-loki") + + try: + existing = client.containers.get(waf_loki_container_name) + logging.warning( + f"Container {waf_loki_container_name} already exists. Deployment may fail or cause conflicts." 
+ ) + except docker.errors.NotFound: + pass # No conflict, safe to proceed + container_id_loki = client.api.create_container( image="nebula-waf-loki", - name=f"{os.environ['USER']}_nebula-waf-loki", + name=waf_loki_container_name, detach=True, command=command, host_config=host_config_loki, @@ -932,6 +1329,7 @@ def run_waf(self): ) client.api.start(container_id_loki) + Deployer._add_container_to_metadata(waf_loki_container_name) volumes_promtail = ["/var/log/nginx"] @@ -945,9 +1343,19 @@ def run_waf(self): f"{network_name}": client.api.create_endpoint_config(ipv4_address=f"{base}.203") }) + waf_promtail_container_name = self.get_container_name("nebula-waf-promtail") + + try: + existing = client.containers.get(waf_promtail_container_name) + logging.warning( + f"Container {waf_promtail_container_name} already exists. Deployment may fail or cause conflicts." + ) + except docker.errors.NotFound: + pass # No conflict, safe to proceed + container_id_promtail = client.api.create_container( image="nebula-waf-promtail", - name=f"{os.environ['USER']}_nebula-waf-promtail", + name=waf_promtail_container_name, detach=True, volumes=volumes_promtail, host_config=host_config_promtail, @@ -955,48 +1363,79 @@ def run_waf(self): ) client.api.start(container_id_promtail) + Deployer._add_container_to_metadata(waf_promtail_container_name) @staticmethod - def stop_waf(): - """ - Stops all running Docker containers with names starting with '_nebula-waf'. - - Responsibilities: - - Gracefully shuts down and removes all WAF-related containers for the current user. - - Typical use cases: - - Cleaning up WAF containers during shutdown or redeployment of the Nebula system. - """ - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_nebula-waf") + def stop_deployer(): + if os.path.exists(Deployer.DEPLOYER_PID_FILE): + try: + with open(Deployer.DEPLOYER_PID_FILE) as f: + pid = int(f.read()) + os.remove(Deployer.DEPLOYER_PID_FILE) + # Check if process still exists before trying to kill it + if psutil.pid_exists(pid): + os.kill(pid, signal.SIGKILL) + logging.info(f"Deployer process {pid} terminated") + else: + logging.info(f"Deployer process {pid} already terminated") + except (ValueError, OSError) as e: + logging.warning(f"Error stopping deployer process: {e}") + except Exception as e: + logging.warning(f"Unexpected error stopping deployer process: {e}") @staticmethod def stop_all(): """ - Stops all running Nebula-related Docker containers and networks, then terminates the deployer process. + Stops all running NEBULA-related Docker containers and networks, then terminates the deployer process. Responsibilities: - Stops frontend, controller, and WAF containers for the current user. - - Removes all Docker containers and networks with names starting with the user's prefix. + - Removes all Docker containers tracked in the metadata file. - Reads and kills the deployer process using its PID file. - Exits the system cleanly, handling any exceptions during shutdown. Typical use cases: - - Full shutdown and cleanup of all Nebula components and resources on the host system. + - Full shutdown and cleanup of all NEBULA components and resources on the host system. """ print("Closing NEBULA (exiting from components)... 
Please wait") + errors = [] + try: - Deployer.stop_frontend() - Deployer.stop_controller() - Deployer.stop_waf() - DockerUtils.remove_containers_by_prefix(f"{os.environ['USER']}_") - DockerUtils.remove_docker_networks_by_prefix(f"{os.environ['USER']}_") - deployer_pid_file = os.path.join(os.path.dirname(__file__), "deployer.pid") - with open(deployer_pid_file) as f: - pid = int(f.read()) - os.remove(deployer_pid_file) - os.kill(pid, signal.SIGKILL) - sys.exit(0) + # Remove all scenario containers + ScenarioManagement.cleanup_scenario_containers() except Exception as e: - print(f"Nebula is closed with errors {e}") - finally: - sys.exit(0) + errors.append(f"Scenario cleanup error: {e}") + logging.warning(f"Error during scenario cleanup: {e}") + + try: + Deployer._remove_all_containers_from_metadata() + except Exception as e: + errors.append(f"Container cleanup error: {e}") + logging.warning(f"Error during container cleanup: {e}") + + try: + Deployer._remove_all_networks_from_metadata() + except Exception as e: + errors.append(f"Network cleanup error: {e}") + logging.warning(f"Error during network cleanup: {e}") + + try: + # Remove the metadata file after cleanup + if os.path.exists(Deployer.METADATA_FILE): + os.remove(Deployer.METADATA_FILE) + except Exception as e: + errors.append(f"Metadata file removal error: {e}") + logging.warning(f"Error removing metadata file: {e}") + + try: + Deployer.stop_deployer() + except Exception as e: + errors.append(f"Deployer stop error: {e}") + logging.warning(f"Error stopping deployer: {e}") + + if errors: + print(f"NEBULA is closed with errors: {'; '.join(errors)}") + else: + print("NEBULA closed successfully") + + sys.exit(0) diff --git a/app/main.py b/app/main.py index 4c4de4f70..5d14308af 100755 --- a/app/main.py +++ b/app/main.py @@ -17,6 +17,14 @@ help="Controller port (default: 5050)", ) +argparser.add_argument( + "-fcp", + "--federationcontrollerport", + dest="federationcontrollerport", + default=5052, + help="federation controller port port (default: 5052)", +) + argparser.add_argument( "--grafanaport", dest="grafanaport", @@ -64,14 +72,12 @@ help="Stop NEBULA platform or nodes only (use '--stop nodes' to stop only the nodes)", ) -argparser.add_argument("-s", "--simulation", action="store_false", dest="simulation", help="Run simulation") - argparser.add_argument( "-c", "--config", dest="config", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "config"), - help="Config directory path", + help="NEBULA config directory path", ) argparser.add_argument( @@ -79,7 +85,7 @@ "--database", dest="databases", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "databases"), - help="Nebula databases path", + help="NEBULA databases directory path", ) argparser.add_argument( @@ -87,7 +93,7 @@ "--logs", dest="logs", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs"), - help="Logs directory path", + help="NEBULA logs directory path", ) argparser.add_argument( @@ -95,7 +101,7 @@ "--certs", dest="certs", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "certs"), - help="Certs directory path", + help="NEBULA certs directory path", ) argparser.add_argument( @@ -106,24 +112,21 @@ help=".env file path", ) -argparser.add_argument("-dev", "--developement", dest="developement", default=True, help="Nebula for devs") - argparser.add_argument( "-p", "--production", dest="production", action="store_true", default=False, - help="Production mode", + help="Deploy NEBULA in production mode", ) 
argparser.add_argument( - "-ad", - "--advanced", - dest="advanced_analytics", - action="store_true", - default=False, - help="Advanced analytics", + "-pr", + "--prefix", + dest="prefix", + default="dev", + help="Deploy NEBULA components with a prefix", ) argparser.add_argument( diff --git a/nebula/addons/attacks/attacks.py b/nebula/addons/attacks/attacks.py index f5587f365..3d7bd71a7 100755 --- a/nebula/addons/attacks/attacks.py +++ b/nebula/addons/attacks/attacks.py @@ -129,7 +129,7 @@ def create_attack(engine) -> Attack: } # Get attack name and parameters from the engine configuration - attack_params = engine.config.participant["adversarial_args"].get("attack_params", {}) + attack_params = engine.config.participant["addons"]["adversarial_args"].get("attack_params", {}) attack_name = attack_params.get("attacks", None) if attack_name is None: raise AttackException("No attack specified") diff --git a/nebula/addons/gps/nebulagps.py b/nebula/addons/gps/nebulagps.py index 571a5ab53..b1025cf22 100644 --- a/nebula/addons/gps/nebulagps.py +++ b/nebula/addons/gps/nebulagps.py @@ -74,8 +74,8 @@ async def is_running(self): return self._running.is_set() async def get_geoloc(self): - latitude = self._config.participant["mobility_args"]["latitude"] - longitude = self._config.participant["mobility_args"]["longitude"] + latitude = self._config.participant["addons"]["mobility"]["latitude"] + longitude = self._config.participant["addons"]["mobility"]["longitude"] return (latitude, longitude) async def calculate_distance(self, self_lat, self_long, other_lat, other_long): diff --git a/nebula/addons/mobility.py b/nebula/addons/mobility.py index b46f7fe88..fdf4265a4 100755 --- a/nebula/addons/mobility.py +++ b/nebula/addons/mobility.py @@ -57,18 +57,18 @@ def __init__(self, config, verbose=False): self._mobility_task = None # Track the background task # Mobility configuration - self.mobility = self.config.participant["mobility_args"]["mobility"] - self.mobility_type = self.config.participant["mobility_args"]["mobility_type"] - self.grace_time = self.config.participant["mobility_args"]["grace_time_mobility"] - self.period = self.config.participant["mobility_args"]["change_geo_interval"] + self.mobility = self.config.participant["addons"]["mobility"]["enabled"] + self.mobility_type = self.config.participant["addons"]["mobility"]["mobility_type"] + self.grace_time = self.config.participant["addons"]["mobility"]["grace_time_mobility"] + self.period = self.config.participant["addons"]["mobility"]["change_geo_interval"] # INFO: These values may change according to the needs of the federation self.max_distance_with_direct_connections = 150 # meters self.max_movement_random_strategy = 50 # meters self.max_movement_nearest_strategy = 50 # meters self.max_initiate_approximation = self.max_distance_with_direct_connections * 1.2 - self.radius_federation = float(config.participant["mobility_args"]["radius_federation"]) - self.scheme_mobility = config.participant["mobility_args"]["scheme_mobility"] - self.round_frequency = int(config.participant["mobility_args"]["round_frequency"]) + self.radius_federation = float(config.participant["addons"]["mobility"]["radius_federation"]) + self.scheme_mobility = config.participant["addons"]["mobility"]["scheme_mobility"] + self.round_frequency = int(config.participant["addons"]["mobility"]["round_frequency"]) # Logging box with mobility information mobility_msg = f"Mobility: {self.mobility}\nMobility type: {self.mobility_type}\nRadius federation: {self.radius_federation}\nScheme 
mobility: {self.scheme_mobility}\nEach {self.round_frequency} rounds" print_msg_box(msg=mobility_msg, indent=2, title="Mobility information") @@ -267,11 +267,11 @@ async def set_geo_location(self, latitude, longitude): if latitude < -90 or latitude > 90 or longitude < -180 or longitude > 180: # If the new location is out of bounds, we keep the old location - latitude = self.config.participant["mobility_args"]["latitude"] - longitude = self.config.participant["mobility_args"]["longitude"] + latitude = self.config.participant["addons"]["mobility"]["latitude"] + longitude = self.config.participant["addons"]["mobility"]["longitude"] - self.config.participant["mobility_args"]["latitude"] = latitude - self.config.participant["mobility_args"]["longitude"] = longitude + self.config.participant["addons"]["mobility"]["latitude"] = latitude + self.config.participant["addons"]["mobility"]["longitude"] = longitude if self._verbose: logging.info(f"📍 New geo location: {latitude}, {longitude}") cle = ChangeLocationEvent(latitude, longitude) @@ -301,8 +301,8 @@ async def change_geo_location(self): """ if self.mobility and (self.mobility_type == "topology" or self.mobility_type == "both"): random.seed(time.time() + self.config.participant["device_args"]["idx"]) - latitude = float(self.config.participant["mobility_args"]["latitude"]) - longitude = float(self.config.participant["mobility_args"]["longitude"]) + latitude = float(self.config.participant["addons"]["mobility"]["latitude"]) + longitude = float(self.config.participant["addons"]["mobility"]["longitude"]) if True: # Get neighbor closer to me async with self._nodes_distances_lock: diff --git a/nebula/addons/networksimulation/nebulanetworksimulator.py b/nebula/addons/networksimulation/nebulanetworksimulator.py index 9af1f768e..85efbb2a2 100644 --- a/nebula/addons/networksimulation/nebulanetworksimulator.py +++ b/nebula/addons/networksimulation/nebulanetworksimulator.py @@ -35,7 +35,7 @@ def cm(self): async def start(self): logging.info("🌐 Nebula Network Simulator starting...") self._running.set() - grace_time = self.cm.config.participant["mobility_args"]["grace_time_mobility"] + grace_time = self.cm.config.participant["addons"]["mobility"]["grace_time_mobility"] # if self._verbose: logging.info(f"Waiting {grace_time}s to start applying network conditions based on distances between devices") # await asyncio.sleep(grace_time) await EventManager.get_instance().subscribe_addonevent( diff --git a/nebula/addons/reporter.py b/nebula/addons/reporter.py index 376f6a208..5b3952519 100755 --- a/nebula/addons/reporter.py +++ b/nebula/addons/reporter.py @@ -4,6 +4,7 @@ import logging import os import sys +from nebula.controller.federation.utils_requests import NodeUpdateRequest, NodeDoneRequest from typing import TYPE_CHECKING import aiohttp @@ -54,7 +55,7 @@ def __init__(self, config, trainer): self.frequency = self.config.participant["reporter_args"]["report_frequency"] self.grace_time = self.config.participant["reporter_args"]["grace_time_reporter"] self.data_queue = asyncio.Queue() - self.url = f"http://{self.config.participant['scenario_args']['controller']}/nodes/{self.config.participant['scenario_args']['name']}/update" + self.url = f"http://{self.config.participant['scenario_args']['controller']}/nodes/{self.config.participant['scenario_args']['federation_id']}/update" self.counter = 0 self.first_net_metrics = True @@ -170,8 +171,18 @@ async def report_scenario_finished(self): might be temporarily overloaded. 
- Logs exceptions if the connection attempt to the controller fails. """ - url = f"http://{self.config.participant['scenario_args']['controller']}/nodes/{self.config.participant['scenario_args']['name']}/done" - data = json.dumps({"idx": self.config.participant["device_args"]["idx"]}) + url = f"http://{self.config.participant['scenario_args']['controller']}/nodes/{self.config.participant['scenario_args']['federation_id']}/done" + node_done_req = NodeDoneRequest(idx=self.config.participant["device_args"]["idx"], + deployment=self.config.participant["scenario_args"]["deployment"], + name=self.config.participant["scenario_args"]["name"], + federation_id=self.config.participant["scenario_args"]["federation_id"] + ) + payload = node_done_req.model_dump() + data = json.dumps(payload) + # data = json.dumps({"idx": self.config.participant["device_args"]["idx"], + # "deployment": self.config.participant["scenario_args"]["deployment"], + # "name": self.config.participant["scenario_args"]["name"], + # "federation_id": self.config.participant["scenario_args"]["federation_id"]}) headers = { "Content-Type": "application/json", "User-Agent": f"NEBULA Participant {self.config.participant['device_args']['idx']}", @@ -263,11 +274,13 @@ async def __report_status_to_controller(self): - Delays for 5 seconds upon general exceptions to avoid rapid retry loops. """ try: + node_updt_req = NodeUpdateRequest(config=self.config.participant) + payload = node_updt_req.model_dump() async with ( aiohttp.ClientSession() as session, session.post( self.url, - data=json.dumps(self.config.participant), + data=json.dumps(payload), headers={ "Content-Type": "application/json", "User-Agent": f"NEBULA Participant {self.config.participant['device_args']['idx']}", diff --git a/nebula/addons/reputation/reputation.py b/nebula/addons/reputation/reputation.py index 561199513..9377b2d42 100644 --- a/nebula/addons/reputation/reputation.py +++ b/nebula/addons/reputation/reputation.py @@ -3,7 +3,7 @@ import time import numpy as np import torch - +from nebula.core.addonmanager import NebulaAddon from datetime import datetime from typing import TYPE_CHECKING from nebula.addons.functions import print_msg_box @@ -54,14 +54,14 @@ def __init__( self.similarity = [] -class Reputation: +class Reputation(NebulaAddon): """ Class to define and manage the reputation of a participant in the network. The class handles collection of metrics, calculation of static and dynamic reputation, updating history, and communication of reputation scores to neighbors. 
""" - + REPUTATION_THRESHOLD = 0.6 SIMILARITY_THRESHOLD = 0.6 INITIAL_ROUND_FOR_REPUTATION = 1 @@ -70,12 +70,12 @@ class Reputation: WEIGHTED_HISTORY_ROUNDS = 3 FRACTION_ANOMALY_MULTIPLIER = 1.20 THRESHOLD_ANOMALY_MULTIPLIER = 1.15 - + # Augmentation factors LATENCY_AUGMENT_FACTOR = 1.4 MESSAGE_AUGMENT_FACTOR_EARLY = 2.0 MESSAGE_AUGMENT_FACTOR_NORMAL = 1.1 - + # Penalty and decay factors HISTORICAL_PENALTY_THRESHOLD = 0.9 NEGATIVE_LATENCY_PENALTY = 0.3 @@ -104,7 +104,7 @@ def __init__(self, engine: "Engine", config: "Config"): self._addr = engine.addr self._log_dir = engine.log_dir self._idx = engine.idx - + self._initialize_data_structures() self._configure_constants() self._load_configuration() @@ -114,9 +114,9 @@ def __init__(self, engine: "Engine", config: "Config"): def _configure_constants(self): """Configure system constants from config or use defaults.""" - reputation_config = self._config.participant.get("defense_args", {}).get("reputation", {}) + reputation_config = self._config.participant.get("addons", {}).get("reputation", {}) constants_config = reputation_config.get("constants", {}) - + self.REPUTATION_THRESHOLD = constants_config.get("reputation_threshold", self.REPUTATION_THRESHOLD) self.SIMILARITY_THRESHOLD = constants_config.get("similarity_threshold", self.SIMILARITY_THRESHOLD) self.INITIAL_ROUND_FOR_REPUTATION = constants_config.get("initial_round_for_reputation", self.INITIAL_ROUND_FOR_REPUTATION) @@ -168,7 +168,7 @@ def _initialize_data_structures(self): def _load_configuration(self): """Load and validate reputation configuration.""" - reputation_config = self._config.participant["defense_args"]["reputation"] + reputation_config = self._config.participant["addons"]["reputation"] self._enabled = reputation_config["enabled"] self._metrics = reputation_config["metrics"] self._initial_reputation = float(reputation_config["initial_reputation"]) @@ -180,15 +180,15 @@ def _load_configuration(self): def _setup_connection_metrics(self): """Initialize metrics for each neighbor.""" - neighbors_str = self._config.participant["network_args"]["neighbors"] - for neighbor in neighbors_str.split(): + neighbors = self._config.participant["network_args"]["neighbors"] + for neighbor in neighbors: self.connection_metrics[neighbor] = Metrics() def _configure_metric_weights(self): """Configure weights for different metrics based on weighting factor.""" default_weight = 0.25 metric_names = ["model_arrival_latency", "model_similarity", "num_messages", "fraction_parameters_changed"] - + if self._weighting_factor == "static": self._weight_model_arrival_latency = float( self._metrics.get("model_arrival_latency", {}).get("weight", default_weight) @@ -209,7 +209,7 @@ def _configure_metric_weights(self): elif not isinstance(self._metrics[metric_name], dict): self._metrics[metric_name] = {"enabled": bool(self._metrics[metric_name])} self._metrics[metric_name]["weight"] = default_weight - + self._weight_model_arrival_latency = default_weight self._weight_model_similarity = default_weight self._weight_num_messages = default_weight @@ -229,24 +229,24 @@ def engine(self): def _is_metric_enabled(self, metric_name: str, metrics_config: dict = None) -> bool: """ Check if a specific metric is enabled based on the provided configuration. - + Args: metric_name (str): The name of the metric to check. - metrics_config (dict, optional): The configuration dictionary for metrics. + metrics_config (dict, optional): The configuration dictionary for metrics. If None, uses the instance's _metrics. 
- + Returns: bool: True if the metric is enabled, False otherwise. """ config_to_use = metrics_config if metrics_config is not None else getattr(self, '_metrics', None) - + if not isinstance(config_to_use, dict): if metrics_config is not None: logging.warning(f"metrics_config is not a dictionary: {type(metrics_config)}") else: logging.warning("_metrics is not properly initialized") return False - + metric_config = config_to_use.get(metric_name) if metric_config is None: return False @@ -269,7 +269,7 @@ def save_data( ): """ Save data between nodes and aggregated models. - + Args: type_data: Type of data to save ('number_message', 'fraction_of_params_changed', 'model_arrival_latency') nei: Neighbor identifier @@ -290,7 +290,7 @@ def save_data( try: metrics_instance = self.connection_metrics[nei] - + if type_data == "number_message": message_data = {"time": time, "current_round": current_round} if not isinstance(metrics_instance.messages, list): @@ -316,7 +316,7 @@ def save_data( except Exception: logging.exception(f"Error saving data for type {type_data} and neighbor {nei}") - async def setup(self): + async def start(self): """Set up the reputation system by subscribing to relevant events.""" if self._enabled: await EventManager.get_instance().subscribe_node_event(RoundStartEvent, self.on_round_start) @@ -340,24 +340,27 @@ async def setup(self): ) await EventManager.get_instance().subscribe_node_event(DuplicatedMessageEvent, self.recollect_duplicated_number_message) + async def stop(): + pass + async def init_reputation( self, federation_nodes=None, round_num=None, last_feedback_round=None, init_reputation=None ): """ Initialize the reputation system. - + Args: federation_nodes: List of federation node identifiers - round_num: Current round number + round_num: Current round number last_feedback_round: Last round that received feedback init_reputation: Initial reputation value to assign """ if not self._enabled: return - + if not self._validate_init_parameters(federation_nodes, round_num, init_reputation): return - + neighbors = self._validate_federation_nodes(federation_nodes) if not neighbors: logging.error("init_reputation | No valid neighbors found") @@ -370,13 +373,13 @@ def _validate_init_parameters(self, federation_nodes, round_num, init_reputation if not federation_nodes: logging.error("init_reputation | No federation nodes provided") return False - + if round_num is None: logging.warning("init_reputation | Round number not provided") - + if init_reputation is None: logging.warning("init_reputation | Initial reputation value not provided") - + return True async def _initialize_neighbor_reputations(self, neighbors: list, round_num: int, last_feedback_round: int, init_reputation: float): @@ -392,7 +395,7 @@ def _create_or_update_reputation_entry(self, nei: str, round_num: int, last_feed "round": round_num, "last_feedback_round": last_feedback_round, } - + if nei not in self.reputation: self.reputation[nei] = reputation_data elif self.reputation[nei].get("reputation") is None: @@ -401,21 +404,21 @@ def _create_or_update_reputation_entry(self, nei: str, round_num: int, last_feed def _validate_federation_nodes(self, federation_nodes) -> list: """ Validate and filter federation nodes. 
- + Args: federation_nodes: List of federation node identifiers - + Returns: list: List of valid node identifiers """ if not federation_nodes: return [] - + valid_nodes = [node for node in federation_nodes if node and str(node).strip()] - + if not valid_nodes: logging.warning("No valid federation nodes found after filtering") - + return valid_nodes async def _calculate_static_reputation( @@ -429,7 +432,7 @@ async def _calculate_static_reputation( Args: addr: The participant's address - nei: The neighbor's address + nei: The neighbor's address metric_values: Dictionary with metric values """ static_weights = { @@ -440,10 +443,10 @@ async def _calculate_static_reputation( } reputation_static = sum( - metric_values.get(metric_name, 0) * static_weights[metric_name] + metric_values.get(metric_name, 0) * static_weights[metric_name] for metric_name in static_weights ) - + logging.info(f"Static reputation for node {nei} at round {await self.engine.get_round()}: {reputation_static}") avg_reputation = await self.save_reputation_history_in_memory(self.engine.addr, nei, reputation_static) @@ -476,48 +479,48 @@ async def _calculate_dynamic_reputation(self, addr, neighbors): async def _calculate_average_weights(self): """Calculate average weights for all enabled metrics.""" average_weights = {} - + for metric_name in self.history_data.keys(): if self._is_metric_enabled(metric_name): average_weights[metric_name] = await self._get_metric_average_weight(metric_name) - + return average_weights - + async def _get_metric_average_weight(self, metric_name): """Get the average weight for a specific metric.""" if metric_name not in self.history_data or not self.history_data[metric_name]: logging.debug(f"No history data available for metric: {metric_name}") return 0 - + valid_entries = [ entry for entry in self.history_data[metric_name] - if (entry.get("round") is not None and - entry["round"] >= await self._engine.get_round() and + if (entry.get("round") is not None and + entry["round"] >= await self._engine.get_round() and entry.get("weight") not in [None, -1]) ] - + if not valid_entries: return 0 - + try: weights = [entry["weight"] for entry in valid_entries if entry.get("weight") is not None] return sum(weights) / len(weights) if weights else 0 except (TypeError, ZeroDivisionError) as e: logging.warning(f"Error calculating average weight for {metric_name}: {e}") return 0 - + async def _process_neighbors_reputation(self, addr, neighbors, average_weights): """Process reputation calculation for all neighbors.""" for nei in neighbors: metric_values = await self._get_neighbor_metric_values(nei) - + if all(metric_name in metric_values for metric_name in average_weights): await self._update_neighbor_reputation(addr, nei, metric_values, average_weights) - + async def _get_neighbor_metric_values(self, nei): """Get metric values for a specific neighbor in the current round.""" metric_values = {} - + for metric_name in self.history_data: if self._is_metric_enabled(metric_name): for entry in self.history_data.get(metric_name, []): @@ -526,16 +529,16 @@ async def _get_neighbor_metric_values(self, nei): entry.get("nei") == nei): metric_values[metric_name] = entry.get("metric_value", 0) break - + return metric_values - + async def _update_neighbor_reputation(self, addr, nei, metric_values, average_weights): """Update reputation for a specific neighbor.""" reputation_with_weights = sum( - metric_values.get(metric_name, 0) * average_weights[metric_name] + metric_values.get(metric_name, 0) * average_weights[metric_name] for 
metric_name in average_weights ) - + logging.info( f"Dynamic reputation with weights for {nei} at round {await self._engine.get_round()}: {reputation_with_weights}" ) @@ -564,7 +567,7 @@ async def _update_reputation_record(self, nei: str, reputation: float, data: dic data: Additional data to update (currently unused) """ current_round = await self._engine.get_round() - + if nei not in self.reputation: self.reputation[nei] = { "reputation": reputation, @@ -576,7 +579,7 @@ async def _update_reputation_record(self, nei: str, reputation: float, data: dic self.reputation[nei]["round"] = current_round logging.info(f"Reputation of node {nei}: {self.reputation[nei]['reputation']}") - + if self.reputation[nei]["reputation"] < self.REPUTATION_THRESHOLD and current_round > 0: self.rejected_nodes.add(nei) logging.info(f"Rejected node {nei} at round {current_round}") @@ -608,23 +611,23 @@ def calculate_weighted_values( reputation_metrics ) self._add_current_metrics_to_history(active_metrics, history_data, current_round, addr, nei) - + if current_round >= self.INITIAL_ROUND_FOR_REPUTATION and len(active_metrics) > 0: adjusted_weights = self._calculate_dynamic_weights(active_metrics, history_data) else: adjusted_weights = self._calculate_uniform_weights(active_metrics) - + self._update_history_with_weights(active_metrics, history_data, adjusted_weights, current_round, nei) def _ensure_history_data_structure(self, history_data: dict): """Ensure all required keys exist in history data structure.""" required_keys = [ "num_messages", - "model_similarity", + "model_similarity", "fraction_parameters_changed", "model_arrival_latency", ] - + for key in required_keys: if key not in history_data: history_data[key] = [] @@ -644,7 +647,7 @@ def _get_active_metrics( "fraction_parameters_changed": fraction_score_asign, "model_arrival_latency": avg_model_arrival_latency, } - + return {k: v for k, v in all_metrics.items() if self._is_metric_enabled(k, reputation_metrics)} def _add_current_metrics_to_history(self, active_metrics: dict, history_data: dict, current_round: int, addr: str, nei: str): @@ -662,7 +665,7 @@ def _add_current_metrics_to_history(self, active_metrics: dict, history_data: di def _calculate_dynamic_weights(self, active_metrics: dict, history_data: dict) -> dict: """Calculate dynamic weights based on metric deviations.""" deviations = self._calculate_metric_deviations(active_metrics, history_data) - + if all(deviation == 0.0 for deviation in deviations.values()): return self._generate_random_weights(active_metrics) else: @@ -672,7 +675,7 @@ def _calculate_dynamic_weights(self, active_metrics: dict, history_data: dict) - def _calculate_metric_deviations(self, active_metrics: dict, history_data: dict) -> dict: """Calculate deviations of current metrics from historical means.""" deviations = {} - + for metric_name, current_value in active_metrics.items(): historical_values = history_data[metric_name] metric_values = [ @@ -680,11 +683,11 @@ def _calculate_metric_deviations(self, active_metrics: dict, history_data: dict) for entry in historical_values if "metric_value" in entry and entry["metric_value"] != 0 ] - + mean_value = np.mean(metric_values) if metric_values else 0 deviation = abs(current_value - mean_value) deviations[metric_name] = deviation - + return deviations def _generate_random_weights(self, active_metrics: dict) -> dict: @@ -692,7 +695,7 @@ def _generate_random_weights(self, active_metrics: dict) -> dict: num_metrics = len(active_metrics) random_weights = [random.random() for _ in 
range(num_metrics)] total_random_weight = sum(random_weights) - + return { metric_name: weight / total_random_weight for metric_name, weight in zip(active_metrics, random_weights, strict=False) @@ -702,14 +705,14 @@ def _normalize_deviation_weights(self, deviations: dict) -> dict: """Normalize weights based on deviations.""" max_deviation = max(deviations.values()) if deviations else 1 normalized_weights = { - metric_name: (deviation / max_deviation) + metric_name: (deviation / max_deviation) for metric_name, deviation in deviations.items() } - + total_weight = sum(normalized_weights.values()) if total_weight > 0: return { - metric_name: weight / total_weight + metric_name: weight / total_weight for metric_name, weight in normalized_weights.items() } else: @@ -720,20 +723,20 @@ def _adjust_weights_with_minimum(self, normalized_weights: dict, deviations: dic """Apply minimum weight constraints and renormalize.""" mean_deviation = np.mean(list(deviations.values())) dynamic_min_weight = max(self.DYNAMIC_MIN_WEIGHT_THRESHOLD, mean_deviation / (mean_deviation + 1)) - + adjusted_weights = {} total_adjusted_weight = 0 - + for metric_name, weight in normalized_weights.items(): adjusted_weight = max(weight, dynamic_min_weight) adjusted_weights[metric_name] = adjusted_weight total_adjusted_weight += adjusted_weight - + # Renormalize if total weight exceeds 1 if total_adjusted_weight > 1: for metric_name in adjusted_weights: adjusted_weights[metric_name] /= total_adjusted_weight - + return adjusted_weights def _calculate_uniform_weights(self, active_metrics: dict) -> dict: @@ -748,8 +751,8 @@ def _update_history_with_weights(self, active_metrics: dict, history_data: dict, for metric_name in active_metrics: weight = weights.get(metric_name, -1) for entry in history_data[metric_name]: - if (entry["metric_name"] == metric_name and - entry["round"] == current_round and + if (entry["metric_name"] == metric_name and + entry["round"] == current_round and entry["nei"] == nei): entry["weight"] = weight @@ -765,7 +768,7 @@ async def calculate_value_metrics(self, addr, nei, metrics_active=None): try: current_round = await self._engine.get_round() metrics_instance = self.connection_metrics.get(nei) - + if not metrics_instance: logging.warning(f"No metrics found for neighbor {nei}") return self._get_default_metric_values() @@ -778,7 +781,7 @@ async def calculate_value_metrics(self, addr, nei, metrics_active=None): } self._log_metrics_graphics(metric_results, addr, nei, current_round) - + return ( metric_results["messages"]["avg"], metric_results["similarity"], @@ -802,7 +805,7 @@ def _process_num_messages_metric(self, metrics_instance, addr: str, nei: str, cu filtered_messages = [ msg for msg in metrics_instance.messages if msg.get("current_round") == current_round ] - + for msg in filtered_messages: self.messages_number_message.append({ "number_message": msg.get("time"), @@ -813,9 +816,9 @@ def _process_num_messages_metric(self, metrics_instance, addr: str, nei: str, cu normalized, count = self.manage_metric_number_message( self.messages_number_message, addr, nei, current_round, True ) - + avg = self.save_number_message_history(addr, nei, normalized, current_round) - + if avg is None and current_round > self.HISTORY_ROUNDS_LOOKBACK: avg = self.number_message_history[(addr, nei)][current_round - 1]["avg_number_message"] @@ -901,7 +904,7 @@ def _process_model_arrival_latency_metric(self, metrics_instance, addr: str, nei if avg_latency is None and current_round > 1: avg_latency = 
self.model_arrival_latency_history[(addr, nei)][current_round - 1]["score"] return avg_latency or 0 - + return 0 def _process_model_similarity_metric(self, nei: str, current_round: int, metrics_active) -> float: @@ -938,7 +941,7 @@ def create_graphics_to_metrics( ): """ Create and log graphics for reputation metrics. - + Args: number_message_count: Count of messages for logging number_message_norm: Normalized message metric @@ -952,25 +955,25 @@ def create_graphics_to_metrics( """ if current_round is None or current_round >= total_rounds: return - + self.engine.trainer._logger.log_data( - {f"R-Model_arrival_latency_reputation/{addr}": {nei: model_arrival_latency}}, + {f"R-Model_arrival_latency_reputation/{addr}": {nei: model_arrival_latency}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count}}, + {f"R-Count_messages_number_message_reputation/{addr}": {nei: number_message_count}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-number_message_reputation/{addr}": {nei: number_message_norm}}, + {f"R-number_message_reputation/{addr}": {nei: number_message_norm}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Similarity_reputation/{addr}": {nei: similarity}}, + {f"R-Similarity_reputation/{addr}": {nei: similarity}}, step=current_round ) self.engine.trainer._logger.log_data( - {f"R-Fraction_reputation/{addr}": {nei: fraction}}, + {f"R-Fraction_reputation/{addr}": {nei: fraction}}, step=current_round ) @@ -991,7 +994,7 @@ def analyze_anomalies( try: key = (addr, nei, current_round) self._initialize_fraction_history_entry(key, fraction_changed, threshold) - + if current_round == 0: return self._handle_initial_round_anomalies(key, fraction_changed, threshold) else: @@ -1032,16 +1035,16 @@ def _handle_subsequent_round_anomalies( ) -> float: """Handle anomaly analysis for subsequent rounds.""" prev_stats = self._find_previous_valid_stats(addr, nei, current_round) - + if prev_stats is None: logging.warning(f"No valid previous stats found for {addr}, {nei}, round {current_round}") return 1.0 - + anomalies = self._detect_anomalies(fraction_changed, threshold, prev_stats) values = self._calculate_anomaly_values(fraction_changed, threshold, prev_stats, anomalies) fraction_score = self._calculate_combined_score(values) self._update_fraction_statistics(key, fraction_changed, threshold, prev_stats, anomalies, fraction_score) - + return max(fraction_score, 0) def _find_previous_valid_stats(self, addr: str, nei: str, current_round: int) -> dict: @@ -1049,18 +1052,18 @@ def _find_previous_valid_stats(self, addr: str, nei: str, current_round: int) -> for i in range(1, current_round + 1): candidate_key = (addr, nei, current_round - i) candidate_data = self.fraction_changed_history.get(candidate_key, {}) - + required_keys = ["mean_fraction", "std_dev_fraction", "mean_threshold", "std_dev_threshold"] if all(candidate_data.get(k) is not None for k in required_keys): return candidate_data - + return None def _detect_anomalies(self, current_fraction: float, current_threshold: float, prev_stats: dict) -> dict: """Detect if current values are anomalous compared to previous statistics.""" upper_mean_fraction = (prev_stats["mean_fraction"] + prev_stats["std_dev_fraction"]) * self.FRACTION_ANOMALY_MULTIPLIER upper_mean_threshold = (prev_stats["mean_threshold"] + prev_stats["std_dev_threshold"]) * self.THRESHOLD_ANOMALY_MULTIPLIER - + return { "fraction_anomaly": current_fraction > 
upper_mean_fraction, "threshold_anomaly": current_threshold > upper_mean_threshold, @@ -1074,19 +1077,19 @@ def _calculate_anomaly_values( """Calculate penalty values for fraction and threshold anomalies.""" fraction_value = 1.0 threshold_value = 1.0 - + if anomalies["fraction_anomaly"]: mean_fraction_prev = prev_stats["mean_fraction"] if mean_fraction_prev > 0: penalization_factor = abs(current_fraction - mean_fraction_prev) / mean_fraction_prev fraction_value = 1 - (1 / (1 + np.exp(-penalization_factor))) - + if anomalies["threshold_anomaly"]: mean_threshold_prev = prev_stats["mean_threshold"] if mean_threshold_prev > 0: penalization_factor = abs(current_threshold - mean_threshold_prev) / mean_threshold_prev threshold_value = 1 - (1 / (1 + np.exp(-penalization_factor))) - + return { "fraction_value": fraction_value, "threshold_value": threshold_value, @@ -1099,19 +1102,19 @@ def _calculate_combined_score(self, values: dict) -> float: return fraction_weight * values["fraction_value"] + threshold_weight * values["threshold_value"] def _update_fraction_statistics( - self, key: tuple, current_fraction: float, current_threshold: float, + self, key: tuple, current_fraction: float, current_threshold: float, prev_stats: dict, anomalies: dict, fraction_score: float ): """Update the fraction statistics for the current round.""" self.fraction_changed_history[key]["fraction_anomaly"] = anomalies["fraction_anomaly"] self.fraction_changed_history[key]["threshold_anomaly"] = anomalies["threshold_anomaly"] - + self.fraction_changed_history[key]["mean_fraction"] = (current_fraction + prev_stats["mean_fraction"]) / 2 self.fraction_changed_history[key]["mean_threshold"] = (current_threshold + prev_stats["mean_threshold"]) / 2 - + fraction_variance = ((current_fraction - prev_stats["mean_fraction"]) ** 2 + prev_stats["std_dev_fraction"] ** 2) / 2 threshold_variance = ((self.THRESHOLD_VARIANCE_MULTIPLIER * (current_threshold - prev_stats["mean_threshold"]) ** 2) + prev_stats["std_dev_threshold"] ** 2) / 2 - + self.fraction_changed_history[key]["std_dev_fraction"] = np.sqrt(fraction_variance) self.fraction_changed_history[key]["std_dev_threshold"] = np.sqrt(threshold_variance) self.fraction_changed_history[key]["fraction_score"] = fraction_score @@ -1132,9 +1135,9 @@ def manage_model_arrival_latency(self, addr, nei, latency, current_round, round_ """ try: current_key = nei - + self._initialize_latency_round_entry(current_round, current_key, latency) - + if current_round >= 1: score = self._calculate_latency_score(current_round, current_key, latency) self._update_latency_entry_with_score(current_round, current_key, score) @@ -1161,17 +1164,17 @@ def _calculate_latency_score(self, current_round: int, current_key: str, latency """Calculate the latency score based on historical data.""" target_round = self._get_target_round_for_latency(current_round) all_latencies = self._get_all_latencies_for_round(target_round) - + if not all_latencies: return 0.0 - + mean_latency = np.mean(all_latencies) augment_mean = mean_latency * self.LATENCY_AUGMENT_FACTOR - + if latency is None: logging.info(f"latency is None in round {current_round} for nei {current_key}") return -0.5 - + if latency <= augment_mean: return 1.0 else: @@ -1195,7 +1198,7 @@ def _update_latency_entry_with_score(self, current_round: int, current_key: str, target_round = self._get_target_round_for_latency(current_round) all_latencies = self._get_all_latencies_for_round(target_round) mean_latency = np.mean(all_latencies) if all_latencies else 0 - + 
self.model_arrival_latency_history[current_round][current_key].update({ "mean_latency": mean_latency, "score": score, @@ -1215,9 +1218,9 @@ def save_model_arrival_latency_history(self, nei, model_arrival_latency, round_n """ try: current_key = nei - + self._initialize_latency_history_entry(round_num, current_key, model_arrival_latency) - + if model_arrival_latency > 0 and round_num >= 1: avg_model_arrival_latency = self._calculate_latency_weighted_average_positive( round_num, current_key, model_arrival_latency @@ -1236,7 +1239,7 @@ def save_model_arrival_latency_history(self, nei, model_arrival_latency, round_n ) return avg_model_arrival_latency - + except Exception: logging.exception("Error saving model_arrival_latency history") @@ -1284,14 +1287,14 @@ def manage_metric_number_message( ) -> tuple[float, int]: """ Manage the number of messages metric for a specific neighbor. - + Args: messages_number_message: List of message data addr: Source address nei: Neighbor address current_round: Current round number metric_active: Whether the metric is active - + Returns: Tuple of (normalized_messages, messages_count) """ @@ -1301,13 +1304,13 @@ def manage_metric_number_message( messages_count = self._count_relevant_messages(messages_number_message, addr, nei, current_round) neighbor_stats = self._calculate_neighbor_statistics(messages_number_message, current_round) - + normalized_messages = self._calculate_normalized_messages(messages_count, neighbor_stats) - + normalized_messages = self._apply_historical_penalty( normalized_messages, addr, nei, current_round ) - + self._store_message_history(addr, nei, current_round, normalized_messages) normalized_messages = max(0.001, normalized_messages) @@ -1339,7 +1342,7 @@ def _calculate_neighbor_statistics(self, messages: list, current_round: int) -> neighbor_counts[key] = neighbor_counts.get(key, 0) + 1 counts_all_neighbors = list(neighbor_counts.values()) - + if not counts_all_neighbors: return { "percentile_reference": 0, @@ -1349,7 +1352,7 @@ def _calculate_neighbor_statistics(self, messages: list, current_round: int) -> } mean_messages = np.mean(counts_all_neighbors) - + return { "percentile_reference": np.percentile(counts_all_neighbors, 25), "std_dev": np.std(counts_all_neighbors), @@ -1361,10 +1364,10 @@ def _calculate_normalized_messages(self, messages_count: int, neighbor_stats: di """Calculate normalized message score with relative and extra penalties.""" normalized_messages = 1.0 penalties_applied = [] - + relative_increase = self._calculate_relative_increase(messages_count, neighbor_stats["percentile_reference"]) dynamic_margin = self._calculate_dynamic_margin(neighbor_stats) - + if relative_increase > dynamic_margin: penalty_ratio = self._calculate_penalty_ratio(relative_increase, dynamic_margin) normalized_messages *= np.exp(-(penalty_ratio**2)) @@ -1400,7 +1403,7 @@ def _calculate_penalty_ratio(self, relative_increase: float, dynamic_margin: flo def _should_apply_extra_penalty(self, messages_count: int, neighbor_stats: dict) -> bool: """Determine if extra penalty should be applied.""" - return (neighbor_stats["mean_messages"] > 0 and + return (neighbor_stats["mean_messages"] > 0 and messages_count > neighbor_stats["augment_mean"]) def _calculate_extra_penalty_factor(self, messages_count: int, neighbor_stats: dict) -> float: @@ -1408,7 +1411,7 @@ def _calculate_extra_penalty_factor(self, messages_count: int, neighbor_stats: d epsilon = 1e-6 mean_messages = neighbor_stats["mean_messages"] augment_mean = neighbor_stats["augment_mean"] - + 
extra_penalty = (messages_count - mean_messages) / (mean_messages + epsilon) amplification = 1 + (augment_mean / (mean_messages + epsilon)) return extra_penalty * amplification @@ -1417,27 +1420,27 @@ def _apply_historical_penalty(self, normalized_messages: float, addr: str, nei: """Apply historical penalty based on previous round's score.""" if current_round <= 1: return normalized_messages - + prev_data = ( self.number_message_history.get((addr, nei), {}) .get(current_round - 1, {}) ) - + prev_score = prev_data.get("normalized_messages") was_previously_penalized = prev_data.get("was_penalized", False) - + if prev_score is not None and prev_score < self.HISTORICAL_PENALTY_THRESHOLD: original_score = normalized_messages - + if was_previously_penalized: penalty_factor = self.HISTORICAL_PENALTY_THRESHOLD * 0.8 logging.debug(f"Repeated penalty applied to {nei}: stricter historical penalty") else: penalty_factor = self.HISTORICAL_PENALTY_THRESHOLD - + normalized_messages *= penalty_factor logging.debug(f"Historical penalty applied to {nei}: {original_score:.4f} -> {normalized_messages:.4f} (prev_score: {prev_score:.4f}, was_penalized: {was_previously_penalized})") - + return normalized_messages def _store_message_history(self, addr: str, nei: str, current_round: int, normalized_messages: float): @@ -1445,9 +1448,9 @@ def _store_message_history(self, addr: str, nei: str, current_round: int, normal key = (addr, nei) if key not in self.number_message_history: self.number_message_history[key] = {} - + was_penalized = normalized_messages < 1.0 - + self.number_message_history[key][current_round] = { "normalized_messages": normalized_messages, "was_penalized": was_penalized, @@ -1464,9 +1467,9 @@ def save_number_message_history(self, addr, nei, messages_number_message_normali """ try: key = (addr, nei) - + self._initialize_message_history_entry(key, current_round, messages_number_message_normalized) - + if messages_number_message_normalized > 0 and current_round >= 1: avg_number_message = self._calculate_weighted_average_positive(key, current_round, messages_number_message_normalized) elif messages_number_message_normalized == 0 and current_round >= 1: @@ -1478,7 +1481,7 @@ def save_number_message_history(self, addr, nei, messages_number_message_normali self.number_message_history[key][current_round]["avg_number_message"] = avg_number_message return avg_number_message - + except Exception: logging.exception("Error saving number_message history") return -1 @@ -1524,7 +1527,7 @@ async def save_reputation_history_in_memory(self, addr: str, nei: str, reputatio Args: addr: The node's identifier - nei: The neighboring node identifier + nei: The neighboring node identifier reputation: The reputation value to save Returns: @@ -1533,27 +1536,27 @@ async def save_reputation_history_in_memory(self, addr: str, nei: str, reputatio try: key = (addr, nei) current_round = await self._engine.get_round() - + if key not in self.reputation_history: self.reputation_history[key] = {} self.reputation_history[key][current_round] = reputation rounds = sorted(self.reputation_history[key].keys(), reverse=True)[:2] - + if len(rounds) >= 2: current_rep = self.reputation_history[key][rounds[0]] previous_rep = self.reputation_history[key][rounds[1]] - + current_weight = self.REPUTATION_CURRENT_WEIGHT previous_weight = self.REPUTATION_FEEDBACK_WEIGHT avg_reputation = (current_rep * current_weight) + (previous_rep * previous_weight) - + logging.info(f"Current reputation: {current_rep}, Previous reputation: {previous_rep}") 
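For reference, a minimal standalone sketch of the two-round smoothing performed in save_reputation_history_in_memory just below. The 0.9 / 0.1 split is a placeholder: the real REPUTATION_CURRENT_WEIGHT and REPUTATION_FEEDBACK_WEIGHT constants are defined elsewhere in the class and may differ.

# Editor's sketch, not part of the patch. Assumes the history dict already
# contains at least the current round's reputation, as in the patched method.
def smooth_reputation(history: dict[int, float],
                      current_weight: float = 0.9,
                      feedback_weight: float = 0.1) -> float:
    # Take the two most recent rounds (highest round numbers).
    rounds = sorted(history, reverse=True)[:2]
    if len(rounds) < 2:
        # Only one value recorded: return it unchanged.
        return history[rounds[0]]
    current_rep, previous_rep = history[rounds[0]], history[rounds[1]]
    return current_rep * current_weight + previous_rep * feedback_weight

# smooth_reputation({4: 0.8, 5: 0.6}) -> 0.6 * 0.9 + 0.8 * 0.1 = 0.62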
logging.info(f"Reputation ponderated: {avg_reputation}") else: avg_reputation = reputation - + return avg_reputation except Exception: @@ -1577,23 +1580,23 @@ def calculate_similarity_from_metrics(self, nei: str, current_round: int) -> flo return 0.0 relevant_metrics = [ - metric for metric in metrics_instance.similarity + metric for metric in metrics_instance.similarity if metric.get("nei") == nei and metric.get("current_round") == current_round ] - + if not relevant_metrics: relevant_metrics = [ - metric for metric in metrics_instance.similarity + metric for metric in metrics_instance.similarity if metric.get("nei") == nei ] - + if not relevant_metrics: return 0.0 neighbor_metric = relevant_metrics[-1] similarity_weights = { "cosine": 0.25, - "euclidean": 0.25, + "euclidean": 0.25, "manhattan": 0.25, "pearson_correlation": 0.25, } @@ -1604,7 +1607,7 @@ def calculate_similarity_from_metrics(self, nei: str, current_round: int) -> flo ) return max(0.0, min(1.0, similarity_value)) - + except Exception: return 0.0 @@ -1620,9 +1623,9 @@ async def calculate_reputation(self, ae: AggregationEvent): (updates, _, _) = await ae.get_event_data() await self._log_reputation_calculation_start() - + neighbors = set(await self._engine._cm.get_addrs_current_connections(only_direct=True)) - + await self._process_neighbor_metrics(neighbors) await self._calculate_reputation_by_factor(neighbors) await self._handle_initial_reputation() @@ -1644,7 +1647,7 @@ async def _process_neighbor_metrics(self, neighbors): metrics = await self.calculate_value_metrics( self._addr, nei, metrics_active=self._metrics ) - + if self._weighting_factor == "dynamic": await self._process_dynamic_metrics(nei, metrics) elif self._weighting_factor == "static" and await self._engine.get_round() >= 1: @@ -1653,7 +1656,7 @@ async def _process_neighbor_metrics(self, neighbors): async def _process_dynamic_metrics(self, nei, metrics): """Process metrics for dynamic weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics - + self.calculate_weighted_values( metric_messages_number, metric_similarity, @@ -1669,7 +1672,7 @@ async def _process_dynamic_metrics(self, nei, metrics): async def _process_static_metrics(self, nei, metrics): """Process metrics for static weighting factor.""" (metric_messages_number, metric_similarity, metric_fraction, metric_model_arrival_latency) = metrics - + metric_values_dict = { "num_messages": metric_messages_number, "model_similarity": metric_similarity, @@ -1686,7 +1689,7 @@ async def _calculate_reputation_by_factor(self, neighbors): async def _handle_initial_reputation(self): """Handle reputation initialization for the first round.""" if await self._engine.get_round() < 1 and self._enabled: - federation = self._engine.config.participant["network_args"]["neighbors"].split() + federation = self._engine.config.participant["network_args"]["neighbors"] await self.init_reputation( federation_nodes=federation, round_num=await self._engine.get_round(), @@ -1698,7 +1701,7 @@ async def _process_feedback(self): """Process and include feedback in reputation.""" status = await self.include_feedback_in_reputation() current_round = await self._engine.get_round() - + if status: logging.info(f"Feedback included in reputation at round {current_round}") else: @@ -1735,7 +1738,7 @@ async def send_reputation_to_neighbors(self, neighbors): def create_graphic_reputation(self, addr: str, round_num: int): """ Log reputation data for visualization. 
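As an aside, the equal-weight aggregation used by calculate_similarity_from_metrics above can be summarised by the short sketch below: the four similarity metrics are averaged with weight 0.25 each and the result is clamped to [0, 1]. This is an illustrative rewrite of the logic in the patch, not a drop-in replacement.

def aggregate_similarity(metric_entry: dict) -> float:
    # Equal weights, as in the patched method.
    weights = {
        "cosine": 0.25,
        "euclidean": 0.25,
        "manhattan": 0.25,
        "pearson_correlation": 0.25,
    }
    value = sum(w * metric_entry.get(name, 0) for name, w in weights.items())
    # Clamp to the [0, 1] range expected by the reputation computation.
    return max(0.0, min(1.0, value))

# aggregate_similarity({"cosine": 0.9, "euclidean": 0.8,
#                       "manhattan": 0.7, "pearson_correlation": 1.2}) == 0.9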
- + Args: addr: The node address round_num: The round number for logging step @@ -1746,7 +1749,7 @@ def create_graphic_reputation(self, addr: str, round_num: int): for node_id, data in self.reputation.items() if data.get("reputation") is not None } - + if valid_reputations: reputation_data = {f"Reputation/{addr}": valid_reputations} self._engine.trainer._logger.log_data(reputation_data, step=round_num) @@ -1954,26 +1957,26 @@ def _recalculate_pending_latencies(self, current_round): async def recollect_similarity(self, ure: UpdateReceivedEvent): """ Collect and analyze model similarity metrics. - + Args: ure: UpdateReceivedEvent containing model and metadata """ (decoded_model, weight, nei, round_num, local) = await ure.get_event_data() - + if not (self._enabled and self._is_metric_enabled("model_similarity")): return - - if not self._engine.config.participant["adaptive_args"]["model_similarity"]: + + if not self._engine.config.participant["addons"]["reputation"]["adaptive_args"] and not self._engine.config.participant["addons"]["reputation"]["adaptive_args"]["model_similarity"]: return - + if nei == self._addr: return - + logging.info("🤖 handle_model_message | Checking model similarity") - + local_model = self._engine.trainer.get_model_parameters() similarity_values = self._calculate_all_similarity_metrics(local_model, decoded_model) - + similarity_metrics = { "timestamp": datetime.now(), "nei": nei, @@ -1996,7 +1999,7 @@ def _calculate_all_similarity_metrics(self, local_model: dict, received_model: d "jaccard": 0.0, "minkowski": 0.0, } - + similarity_functions = [ ("cosine", cosine_metric), ("euclidean", euclidean_metric), @@ -2004,29 +2007,29 @@ def _calculate_all_similarity_metrics(self, local_model: dict, received_model: d ("pearson_correlation", pearson_correlation_metric), ("jaccard", jaccard_metric), ] - + similarity_values = {} - + for name, metric_func in similarity_functions: try: similarity_values[name] = metric_func(local_model, received_model, similarity=True) except Exception: similarity_values[name] = 0.0 - + try: similarity_values["minkowski"] = minkowski_metric( local_model, received_model, p=2, similarity=True ) except Exception: similarity_values["minkowski"] = 0.0 - + return similarity_values def _store_similarity_metrics(self, nei: str, similarity_metrics: dict): """Store similarity metrics for the given neighbor.""" if nei not in self.connection_metrics: self.connection_metrics[nei] = Metrics() - + self.connection_metrics[nei].similarity.append(similarity_metrics) def _check_similarity_threshold(self, nei: str, cosine_value: float): @@ -2064,25 +2067,25 @@ async def _record_message_data(self, source: str): async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEvent): """ Collect and analyze the fraction of parameters that changed between models. 
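Before the implementation that follows, here is a minimal sketch of the fraction-of-changed-parameters check it performs: count the tensor elements whose absolute difference between the local and received models exceeds a threshold, and divide by the total number of compared elements. The per-round threshold smoothing from the patch is simplified here to a single fixed value.

# Editor's sketch, assuming PyTorch state_dict-style parameter dicts.
import torch

def fraction_of_params_changed(local: dict, received: dict, threshold: float) -> float:
    total, changed = 0, 0
    for key, local_tensor in local.items():
        if key not in received:
            continue
        diff = torch.abs(local_tensor.cpu() - received[key].cpu())
        total += diff.numel()
        changed += torch.sum(diff > threshold).item()
    return changed / total if total > 0 else 0.0

# Example: if roughly 5% of the weights in each shared layer differ by more
# than the threshold, the returned fraction is close to 0.05.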
- + Args: ure: UpdateReceivedEvent containing model and metadata """ (decoded_model, weight, source, round_num, local) = await ure.get_event_data() - + current_round = await self._engine.get_round() parameters_local = self._engine.trainer.get_model_parameters() - + prev_threshold = self._get_previous_threshold(source, current_round) differences = self._calculate_parameter_differences(parameters_local, decoded_model) current_threshold = self._calculate_threshold(differences, prev_threshold) - + changed_params, total_params, changes_record = self._count_changed_parameters( parameters_local, decoded_model, current_threshold ) - + fraction_changed = changed_params / total_params if total_params > 0 else 0.0 - + self._store_fraction_data(source, current_round, { "fraction_changed": fraction_changed, "total_params": total_params, @@ -2102,7 +2105,7 @@ async def recollect_fraction_of_parameters_changed(self, ure: UpdateReceivedEven def _get_previous_threshold(self, source: str, current_round: int) -> float: """Get the threshold from the previous round for the given source.""" - if (source in self.fraction_of_params_changed and + if (source in self.fraction_of_params_changed and current_round - 1 in self.fraction_of_params_changed[source]): return self.fraction_of_params_changed[source][current_round - 1][-1]["threshold"] return None @@ -2122,7 +2125,7 @@ def _calculate_threshold(self, differences: list, prev_threshold: float) -> floa """Calculate the threshold for determining parameter changes.""" if not differences: return 0 - + mean_threshold = torch.mean(torch.tensor(differences)).item() if prev_threshold is not None: return (prev_threshold + mean_threshold) / 2 @@ -2133,20 +2136,20 @@ def _count_changed_parameters(self, local_params: dict, received_params: dict, t total_params = 0 changed_params = 0 changes_record = {} - + for key in local_params.keys(): if key in received_params: local_tensor = local_params[key].cpu() received_tensor = received_params[key].cpu() diff = torch.abs(local_tensor - received_tensor) total_params += diff.numel() - + num_changed = torch.sum(diff > threshold).item() changed_params += num_changed - + if num_changed > 0: changes_record[key] = num_changed - + return changed_params, total_params, changes_record def _store_fraction_data(self, source: str, current_round: int, data: dict): @@ -2155,5 +2158,5 @@ def _store_fraction_data(self, source: str, current_round: int, data: dict): self.fraction_of_params_changed[source] = {} if current_round not in self.fraction_of_params_changed[source]: self.fraction_of_params_changed[source][current_round] = [] - - self.fraction_of_params_changed[source][current_round].append(data) \ No newline at end of file + + self.fraction_of_params_changed[source][current_round].append(data) diff --git a/nebula/addons/topologymanager.py b/nebula/addons/topologymanager.py index c29937372..eda6429cd 100755 --- a/nebula/addons/topologymanager.py +++ b/nebula/addons/topologymanager.py @@ -322,15 +322,16 @@ def update_nodes(self, config_participants): def get_neighbors_string(self, node_idx): """ - Retrieves the neighbors of a given node as a string representation. + Retrieves the neighbors of a given node as a list of string representations. - This method checks the `topology` attribute to find the neighbors of the node at the specified index (`node_idx`). It then returns a string that lists the coordinates of each neighbor. + This method checks the `topology` attribute to find the neighbors of the node at the specified index (`node_idx`). 
+ It then returns a list that contains the coordinates of each neighbor in string format. Parameters: node_idx (int): The index of the node for which neighbors are to be retrieved. Returns: - str: A space-separated string of neighbors' coordinates in the format "latitude:longitude". + list[str]: A list of neighbors' coordinates in the format "latitude:longitude". """ logging.info(f"Getting neighbors for node {node_idx}") logging.info(f"Topology shape: {self.topology.shape}") @@ -342,9 +343,8 @@ def get_neighbors_string(self, node_idx): logging.info(f"Found neighbor at index {i}: {self.nodes[i]}") neighbors_data_strings = [f"{i[0]}:{i[1]}" for i in neighbors_data] - neighbors_data_string = " ".join(neighbors_data_strings) - logging.info(f"Neighbors of node participant_{node_idx}: {neighbors_data_string}") - return neighbors_data_string + logging.info(f"Neighbors of node participant_{node_idx}: {neighbors_data_strings}") + return neighbors_data_strings def __ring_topology(self, increase_convergence=False): """ diff --git a/nebula/config/config.py b/nebula/config/config.py index 5ef336e3a..e2d2b9cb2 100755 --- a/nebula/config/config.py +++ b/nebula/config/config.py @@ -55,7 +55,7 @@ def reset_logging_configuration(self): self.__set_default_logging(mode="a") self.__set_training_logging(mode="a") - + def shutdown_logging(self): """ Properly shuts down all loggers and their handlers in the system. @@ -87,7 +87,7 @@ def __default_config(self): def __set_default_logging(self, mode="w"): experiment_name = self.participant["scenario_args"]["name"] - self.log_dir = os.path.join(self.participant["tracking_args"]["log_dir"], experiment_name) + self.log_dir =self.participant["tracking_args"]["log_dir"] if not os.path.exists(self.log_dir): os.makedirs(self.log_dir) self.log_filename = f"{self.log_dir}/participant_{self.participant['device_args']['idx']}" @@ -204,40 +204,46 @@ def add_participants_config(self, participants_config): def add_neighbor_from_config(self, addr): if self.participant != {}: - if self.participant["network_args"]["neighbors"] == "": - self.participant["network_args"]["neighbors"] = addr - self.participant["mobility_args"]["neighbors_distance"][addr] = None + neighbors = self.participant["network_args"]["neighbors"] + + if not neighbors: + self.participant["network_args"]["neighbors"] = [addr] + self.participant["addons"]["mobility"]["neighbors_distance"][addr] = None else: - if addr not in self.participant["network_args"]["neighbors"]: - self.participant["network_args"]["neighbors"] += " " + addr - self.participant["mobility_args"]["neighbors_distance"][addr] = None + if addr not in neighbors: + self.participant["network_args"]["neighbors"].append(addr) + self.participant["addons"]["mobility"]["neighbors_distance"][addr] = None def update_nodes_distance(self, distances: dict): - self.participant["mobility_args"]["neighbors_distance"] = {node: dist for node, (dist, _) in distances.items()} + self.participant["addons"]["mobility"]["neighbors_distance"] = {node: dist for node, (dist, _) in distances.items()} def update_neighbors_from_config(self, current_connections, dest_addr): - final_neighbors = [] - for n in current_connections: - if n != dest_addr: - final_neighbors.append(n) - - final_neighbors_string = " ".join(final_neighbors) - # Update neighbors - self.participant["network_args"]["neighbors"] = final_neighbors_string + final_neighbors = [n for n in current_connections if n != dest_addr] + + # Update neighbors como lista (no string) + 
self.participant["network_args"]["neighbors"] = final_neighbors + # Update neighbors location - self.participant["mobility_args"]["neighbors_distance"] = { - n: self.participant["mobility_args"]["neighbors_distance"][n] + self.participant["addons"]["mobility"]["neighbors_distance"] = { + n: self.participant["addons"]["mobility"]["neighbors_distance"][n] for n in final_neighbors - if n in self.participant["mobility_args"]["neighbors_distance"] + if n in self.participant["addons"]["mobility"]["neighbors_distance"] } - logging.info(f"Final neighbors: {final_neighbors_string} (config updated))") + + logging.info(f"Final neighbors: {final_neighbors} (config updated)") + def remove_neighbor_from_config(self, addr): - if self.participant != {}: - if self.participant["network_args"]["neighbors"] != "": - self.participant["network_args"]["neighbors"] = ( - self.participant["network_args"]["neighbors"].replace(addr, "").replace(" ", " ").strip() - ) + if self.participant: + neighbors = self.participant["network_args"]["neighbors"] + + if addr in neighbors: + neighbors.remove(addr) + self.participant["network_args"]["neighbors"] = neighbors + + if addr in self.participant["addons"]["mobility"]["neighbors_distance"]: + del self.participant["addons"]["mobility"]["neighbors_distance"][addr] + def reload_config_file(self): config_dir = self.participant["tracking_args"]["config_dir"] diff --git a/nebula/controller/database.py b/nebula/controller/database.py deleted file mode 100755 index 7a012fd8a..000000000 --- a/nebula/controller/database.py +++ /dev/null @@ -1,1283 +0,0 @@ -import asyncio -import datetime -import json -import logging -import os -import sqlite3 - -import aiosqlite -from argon2 import PasswordHasher - -user_db_file_location = None -node_db_file_location = None -scenario_db_file_location = None -notes_db_file_location = None - -_node_lock = asyncio.Lock() - -PRAGMA_SETTINGS = [ - "PRAGMA journal_mode=WAL;", - "PRAGMA synchronous=NORMAL;", - "PRAGMA journal_size_limit=1048576;", - "PRAGMA cache_size=10000;", - "PRAGMA temp_store=MEMORY;", - "PRAGMA cache_spill=0;", -] - - -async def setup_database(db_file_location): - """ - Initializes the SQLite database with the required PRAGMA settings. - - This function: - - Connects asynchronously to the specified SQLite database file. - - Applies a predefined list of PRAGMA settings to configure the database. - - Commits the changes after applying the settings. - - Args: - db_file_location (str): Path to the SQLite database file. - - Exceptions: - PermissionError: Logged if the application lacks permission to create or modify the database file. - Exception: Logs any other unexpected error that occurs during setup. - """ - try: - async with aiosqlite.connect(db_file_location) as db: - for pragma in PRAGMA_SETTINGS: - await db.execute(pragma) - await db.commit() - except PermissionError: - logging.info("No permission to create the databases. Change the default databases directory") - except Exception as e: - logging.exception(f"An error has ocurred during setup_database: {e}") - - -async def ensure_columns(conn, table_name, desired_columns): - """ - Ensures that a table contains all the desired columns, adding any that are missing. - - This function: - - Retrieves the current columns of the specified table. - - Compares them with the desired columns. - - Adds any missing columns to the table using ALTER TABLE statements. - - Args: - conn (aiosqlite.Connection): Active connection to the SQLite database. 
- table_name (str): Name of the table to check and modify. - desired_columns (dict): Dictionary mapping column names to their SQL definitions. - - Note: - This operation is committed immediately after all changes are applied. - """ - _c = await conn.execute(f"PRAGMA table_info({table_name});") - existing_columns = [row[1] for row in await _c.fetchall()] - for column_name, column_definition in desired_columns.items(): - if column_name not in existing_columns: - await conn.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_definition};") - await conn.commit() - - -async def initialize_databases(databases_dir): - """ - Initializes all required SQLite databases and their corresponding tables for the system. - - This function: - - Defines paths for user, node, scenario, and notes databases based on the provided directory. - - Sets up each database with appropriate PRAGMA settings. - - Creates necessary tables if they do not exist. - - Ensures all expected columns are present in each table, adding any missing ones. - - Creates a default admin user if no users are present. - - Args: - databases_dir (str): Path to the directory where the database files will be created or accessed. - - Note: - Default credentials (username and password) are taken from environment variables: - - NEBULA_DEFAULT_USER - - NEBULA_DEFAULT_PASSWORD - """ - global user_db_file_location, node_db_file_location, scenario_db_file_location, notes_db_file_location - - user_db_file_location = os.path.join(databases_dir, "users.db") - node_db_file_location = os.path.join(databases_dir, "nodes.db") - scenario_db_file_location = os.path.join(databases_dir, "scenarios.db") - notes_db_file_location = os.path.join(databases_dir, "notes.db") - - await setup_database(user_db_file_location) - await setup_database(node_db_file_location) - await setup_database(scenario_db_file_location) - await setup_database(notes_db_file_location) - - async with aiosqlite.connect(user_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS users ( - user TEXT PRIMARY KEY, - password TEXT, - role TEXT - ); - """ - ) - desired_columns = {"user": "TEXT PRIMARY KEY", "password": "TEXT", "role": "TEXT"} - await ensure_columns(conn, "users", desired_columns) - - async with aiosqlite.connect(node_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS nodes ( - uid TEXT PRIMARY KEY, - idx TEXT, - ip TEXT, - port TEXT, - role TEXT, - neighbors TEXT, - latitude TEXT, - longitude TEXT, - timestamp TEXT, - federation TEXT, - round TEXT, - scenario TEXT, - hash TEXT, - malicious TEXT - ); - """ - ) - desired_columns = { - "uid": "TEXT PRIMARY KEY", - "idx": "TEXT", - "ip": "TEXT", - "port": "TEXT", - "role": "TEXT", - "neighbors": "TEXT", - "latitude": "TEXT", - "longitude": "TEXT", - "timestamp": "TEXT", - "federation": "TEXT", - "round": "TEXT", - "scenario": "TEXT", - "hash": "TEXT", - "malicious": "TEXT", - } - await ensure_columns(conn, "nodes", desired_columns) - - async with aiosqlite.connect(scenario_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS scenarios ( - name TEXT PRIMARY KEY, - start_time TEXT, - end_time TEXT, - title TEXT, - description TEXT, - deployment TEXT, - federation TEXT, - topology TEXT, - nodes TEXT, - nodes_graph TEXT, - n_nodes TEXT, - matrix TEXT, - random_topology_probability TEXT, - dataset TEXT, - iid TEXT, - partition_selection TEXT, - partition_parameter TEXT, - model TEXT, - agg_algorithm TEXT, - rounds TEXT, - logginglevel 
TEXT, - report_status_data_queue TEXT, - accelerator TEXT, - network_subnet TEXT, - network_gateway TEXT, - epochs TEXT, - attack_params TEXT, - reputation TEXT, - random_geo TEXT, - latitude TEXT, - longitude TEXT, - mobility TEXT, - mobility_type TEXT, - radius_federation TEXT, - scheme_mobility TEXT, - round_frequency TEXT, - mobile_participants_percent TEXT, - additional_participants TEXT, - schema_additional_participants TEXT, - status TEXT, - role TEXT, - username TEXT, - gpu_id TEXT - ); - """ - ) - desired_columns = { - "name": "TEXT PRIMARY KEY", - "start_time": "TEXT", - "end_time": "TEXT", - "title": "TEXT", - "description": "TEXT", - "deployment": "TEXT", - "federation": "TEXT", - "topology": "TEXT", - "nodes": "TEXT", - "nodes_graph": "TEXT", - "n_nodes": "TEXT", - "matrix": "TEXT", - "random_topology_probability": "TEXT", - "dataset": "TEXT", - "iid": "TEXT", - "partition_selection": "TEXT", - "partition_parameter": "TEXT", - "model": "TEXT", - "agg_algorithm": "TEXT", - "rounds": "TEXT", - "logginglevel": "TEXT", - "report_status_data_queue": "TEXT", - "accelerator": "TEXT", - "gpu_id": "TEXT", - "network_subnet": "TEXT", - "network_gateway": "TEXT", - "epochs": "TEXT", - "attack_params": "TEXT", - "reputation": "TEXT", - "random_geo": "TEXT", - "latitude": "TEXT", - "longitude": "TEXT", - "mobility": "TEXT", - "mobility_type": "TEXT", - "radius_federation": "TEXT", - "scheme_mobility": "TEXT", - "round_frequency": "TEXT", - "mobile_participants_percent": "TEXT", - "additional_participants": "TEXT", - "schema_additional_participants": "TEXT", - "status": "TEXT", - "role": "TEXT", - "username": "TEXT", - } - await ensure_columns(conn, "scenarios", desired_columns) - - async with aiosqlite.connect(notes_db_file_location) as conn: - await conn.execute( - """ - CREATE TABLE IF NOT EXISTS notes ( - scenario TEXT PRIMARY KEY, - scenario_notes TEXT - ); - """ - ) - desired_columns = {"scenario": "TEXT PRIMARY KEY", "scenario_notes": "TEXT"} - await ensure_columns(conn, "notes", desired_columns) - - username = os.environ.get("NEBULA_DEFAULT_USER", "admin") - password = os.environ.get("NEBULA_DEFAULT_PASSWORD", "admin") - if not list_users(): - add_user(username, password, "admin") - if not verify_hash_algorithm(username): - update_user(username, password, "admin") - - -def list_users(all_info=False): - """ - Retrieves a list of users from the users database. - - Args: - all_info (bool): If True, returns full user records; otherwise, returns only usernames. Default is False. - - Returns: - list: A list of usernames or full user records depending on the all_info flag. - """ - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM users") - result = c.fetchall() - - if not all_info: - result = [user["user"] for user in result] - - return result - - -def get_user_info(user): - """ - Fetches detailed information for a specific user from the users database. - - Args: - user (str): The username to retrieve information for. - - Returns: - sqlite3.Row or None: A row containing the user's information if found, otherwise None. - """ - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - command = "SELECT * FROM users WHERE user = ?" - c.execute(command, (user,)) - result = c.fetchone() - - return result - - -def verify(user, password): - """ - Verifies whether the provided password matches the stored hashed password for a user. 
- - Args: - user (str): The username to verify. - password (str): The plain text password to check against the stored hash. - - Returns: - bool: True if the password is correct, False otherwise. - """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - - c.execute("SELECT password FROM users WHERE user = ?", (user,)) - result = c.fetchone() - if result: - try: - return ph.verify(result[0], password) - except: - return False - return False - - -def verify_hash_algorithm(user): - """ - Checks if the stored password hash for a user uses a supported Argon2 algorithm. - - Args: - user (str): The username to check (case-insensitive, converted to uppercase). - - Returns: - bool: True if the password hash starts with a valid Argon2 prefix, False otherwise. - """ - user = user.upper() - argon2_prefixes = ("$argon2i$", "$argon2id$") - - with sqlite3.connect(user_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - c.execute("SELECT password FROM users WHERE user = ?", (user,)) - result = c.fetchone() - if result: - password_hash = result["password"] - return password_hash.startswith(argon2_prefixes) - - return False - - -def delete_user_from_db(user): - """ - Deletes a user record from the users database. - - Args: - user (str): The username of the user to be deleted. - """ - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute("DELETE FROM users WHERE user = ?", (user,)) - - -def add_user(user, password, role): - """ - Adds a new user to the users database with a hashed password. - - Args: - user (str): The username to add (stored in uppercase). - password (str): The plain text password to hash and store. - role (str): The role assigned to the user. - """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute( - "INSERT INTO users VALUES (?, ?, ?)", - (user.upper(), ph.hash(password), role), - ) - - -def update_user(user, password, role): - """ - Updates the password and role of an existing user in the users database. - - Args: - user (str): The username to update (case-insensitive, stored as uppercase). - password (str): The new plain text password to hash and store. - role (str): The new role to assign to the user. - """ - ph = PasswordHasher() - with sqlite3.connect(user_db_file_location) as conn: - c = conn.cursor() - c.execute( - "UPDATE users SET password = ?, role = ? WHERE user = ?", - (ph.hash(password), role, user.upper()), - ) - - -def list_nodes(scenario_name=None, sort_by="idx"): - """ - Retrieves a list of nodes from the nodes database, optionally filtered by scenario and sorted. - - Args: - scenario_name (str, optional): Name of the scenario to filter nodes by. If None, returns all nodes. - sort_by (str): Column name to sort the results by. Defaults to "idx". - - Returns: - list or None: A list of sqlite3.Row objects representing nodes, or None if an error occurs. - """ - try: - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - if scenario_name: - command = "SELECT * FROM nodes WHERE scenario = ? 
ORDER BY " + sort_by + ";" - c.execute(command, (scenario_name,)) - else: - command = "SELECT * FROM nodes ORDER BY " + sort_by + ";" - c.execute(command) - - result = c.fetchall() - - return result - except sqlite3.Error as e: - print(f"Error occurred while listing nodes: {e}") - return None - - -def list_nodes_by_scenario_name(scenario_name): - """ - Fetches all nodes associated with a specific scenario, ordered by their index as integers. - - Args: - scenario_name (str): The name of the scenario to filter nodes by. - - Returns: - list or None: A list of sqlite3.Row objects for nodes in the scenario, or None if an error occurs. - """ - try: - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - command = "SELECT * FROM nodes WHERE scenario = ? ORDER BY CAST(idx AS INTEGER) ASC;" - c.execute(command, (scenario_name,)) - result = c.fetchall() - - return result - except sqlite3.Error as e: - print(f"Error occurred while listing nodes by scenario name: {e}") - return None - - -async def update_node_record( - node_uid, - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - scenario, - run_hash, - malicious, -): - """ - Inserts or updates a node record in the database for a given scenario, ensuring thread-safe access. - - Args: - node_uid (str): Unique identifier of the node. - idx (str): Index or identifier within the scenario. - ip (str): IP address of the node. - port (str): Port used by the node. - role (str): Role of the node in the federation. - neighbors (str): Neighbors of the node (serialized). - latitude (str): Geographic latitude of the node. - longitude (str): Geographic longitude of the node. - timestamp (str): Timestamp of the last update. - federation (str): Federation identifier the node belongs to. - federation_round (str): Current federation round. - scenario (str): Scenario name the node is part of. - run_hash (str): Hash of the current run/state. - malicious (str): Indicator if the node is malicious. - - Returns: - dict or None: The updated or inserted node record as a dictionary, or None if insertion/update failed. - """ - global _node_lock - async with _node_lock: - async with aiosqlite.connect(node_db_file_location) as conn: - conn.row_factory = aiosqlite.Row - _c = await conn.cursor() - - # Check if the node already exists - await _c.execute("SELECT * FROM nodes WHERE uid = ? AND scenario = ?;", (node_uid, scenario)) - result = await _c.fetchone() - - if result is None: - # Insert new node - await _c.execute( - "INSERT INTO nodes (uid, idx, ip, port, role, neighbors, latitude, longitude, timestamp, federation, round, scenario, hash, malicious) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);", - ( - node_uid, - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - scenario, - run_hash, - malicious, - ), - ) - else: - # Update existing node - await _c.execute( - "UPDATE nodes SET idx = ?, ip = ?, port = ?, role = ?, neighbors = ?, latitude = ?, longitude = ?, timestamp = ?, federation = ?, round = ?, hash = ?, malicious = ? WHERE uid = ? AND scenario = ?;", - ( - idx, - ip, - port, - role, - neighbors, - latitude, - longitude, - timestamp, - federation, - federation_round, - run_hash, - malicious, - node_uid, - scenario, - ), - ) - - await conn.commit() - - # Fetch the updated or newly inserted row - await _c.execute("SELECT * FROM nodes WHERE uid = ? 
AND scenario = ?;", (node_uid, scenario)) - updated_row = await _c.fetchone() - return dict(updated_row) if updated_row else None - - -def remove_all_nodes(): - """ - Deletes all node records from the nodes database. - - This operation removes every entry in the nodes table. - - Returns: - None - """ - with sqlite3.connect(node_db_file_location) as conn: - c = conn.cursor() - command = "DELETE FROM nodes;" - c.execute(command) - - -def remove_nodes_by_scenario_name(scenario_name): - """ - Deletes all nodes associated with a specific scenario from the database. - - Args: - scenario_name (str): The name of the scenario whose nodes should be removed. - - Returns: - None - """ - with sqlite3.connect(node_db_file_location) as conn: - c = conn.cursor() - command = "DELETE FROM nodes WHERE scenario = ?;" - c.execute(command, (scenario_name,)) - - -def get_all_scenarios(username, role, sort_by="start_time"): - """ - Retrieve all scenarios from the database filtered by user role and sorted by a specified field. - - Parameters: - username (str): The username of the requesting user. - role (str): The role of the user, e.g., "admin" or regular user. - sort_by (str, optional): The field name to sort the results by. Defaults to "start_time". - - Returns: - list[sqlite3.Row]: A list of scenario records as SQLite Row objects. - - Behavior: - - Admin users retrieve all scenarios. - - Non-admin users retrieve only scenarios associated with their username. - - Sorting by "start_time" applies custom datetime ordering. - - Other sort fields are applied directly in the ORDER BY clause. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - if role == "admin": - if sort_by == "start_time": - command = """ - SELECT * FROM scenarios - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); - """ - c.execute(command) - else: - command = "SELECT * FROM scenarios ORDER BY ?;" - c.execute(command, (sort_by,)) - else: - if sort_by == "start_time": - command = """ - SELECT * FROM scenarios - WHERE username = ? - ORDER BY strftime('%Y-%m-%d %H:%M:%S', substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8)); - """ - c.execute(command, (username,)) - else: - command = "SELECT * FROM scenarios WHERE username = ? ORDER BY ?;" - c.execute( - command, - ( - username, - sort_by, - ), - ) - result = c.fetchall() - - return result - - -def get_all_scenarios_and_check_completed(username, role, sort_by="start_time"): - """ - Retrieve all scenarios with detailed fields and update the status of running scenarios if their federation is completed. - - Parameters: - username (str): The username of the requesting user. - role (str): The role of the user, e.g., "admin" or regular user. - sort_by (str, optional): The field name to sort the results by. Defaults to "start_time". - - Returns: - list[sqlite3.Row]: A list of scenario records including name, username, title, start_time, model, dataset, rounds, and status. - - Behavior: - - Admin users retrieve all scenarios. - - Non-admin users retrieve only scenarios associated with their username. - - Scenarios are sorted by start_time with special handling for null or empty values. - - For scenarios with status "running", checks if federation is completed: - - If completed, updates the scenario status to "completed". 
- - Refreshes the returned scenario list after updates. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - if role == "admin": - if sort_by == "start_time": - command = """ - SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios - ORDER BY - CASE - WHEN start_time IS NULL OR start_time = '' THEN 1 - ELSE 0 - END, - strftime( - '%Y-%m-%d %H:%M:%S', - substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) - ); - """ - c.execute(command) - else: - command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios ORDER BY ?;" - c.execute(command, (sort_by,)) - result = c.fetchall() - else: - if sort_by == "start_time": - command = """ - SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios - WHERE username = ? - ORDER BY - CASE - WHEN start_time IS NULL OR start_time = '' THEN 1 - ELSE 0 - END, - strftime( - '%Y-%m-%d %H:%M:%S', - substr(start_time, 7, 4) || '-' || substr(start_time, 4, 2) || '-' || substr(start_time, 1, 2) || ' ' || substr(start_time, 12, 8) - ); - """ - c.execute(command, (username,)) - else: - command = "SELECT name, username, title, start_time, model, dataset, rounds, status FROM scenarios WHERE username = ? ORDER BY ?;" - c.execute( - command, - ( - username, - sort_by, - ), - ) - result = c.fetchall() - - for scenario in result: - if scenario["status"] == "running": - if check_scenario_federation_completed(scenario["name"]): - scenario_set_status_to_completed(scenario["name"]) - result = get_all_scenarios(username, role) - - return result - - -def scenario_update_record(name, start_time, end_time, scenario, status, role, username): - """ - Insert a new scenario record or update an existing one in the database based on the scenario name. - - Parameters: - name (str): The unique name identifier of the scenario. - start_time (str): The start time of the scenario. - end_time (str): The end time of the scenario. - scenario (object): An object containing detailed scenario attributes. - status (str): The current status of the scenario. - role (str): The role of the user performing the operation. - username (str): The username of the user performing the operation. - - Behavior: - - Checks if a scenario with the given name exists. - - If not, inserts a new record with all scenario details. - - If exists, updates the existing record with the provided data. - - Commits the transaction to persist changes. 
- """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - select_command = "SELECT * FROM scenarios WHERE name = ?;" - c.execute(select_command, (name,)) - result = c.fetchone() - - if result is None: - insert_command = """ - INSERT INTO scenarios ( - name, - start_time, - end_time, - title, - description, - deployment, - federation, - topology, - nodes, - nodes_graph, - n_nodes, - matrix, - random_topology_probability, - dataset, - iid, - partition_selection, - partition_parameter, - model, - agg_algorithm, - rounds, - logginglevel, - report_status_data_queue, - accelerator, - gpu_id, - network_subnet, - network_gateway, - epochs, - attack_params, - reputation, - random_geo, - latitude, - longitude, - mobility, - mobility_type, - radius_federation, - scheme_mobility, - round_frequency, - mobile_participants_percent, - additional_participants, - schema_additional_participants, - status, - role, - username - ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? - ); - """ - c.execute( - insert_command, - ( - name, - start_time, - end_time, - scenario.scenario_title, - scenario.scenario_description, - scenario.deployment, - scenario.federation, - scenario.topology, - json.dumps(scenario.nodes), - json.dumps(scenario.nodes_graph), - scenario.n_nodes, - json.dumps(scenario.matrix), - scenario.random_topology_probability, - scenario.dataset, - scenario.iid, - scenario.partition_selection, - scenario.partition_parameter, - scenario.model, - scenario.agg_algorithm, - scenario.rounds, - scenario.logginglevel, - scenario.report_status_data_queue, - scenario.accelerator, - json.dumps(scenario.gpu_id), - scenario.network_subnet, - scenario.network_gateway, - scenario.epochs, - json.dumps(scenario.attack_params), - json.dumps(scenario.reputation), - scenario.random_geo, - scenario.latitude, - scenario.longitude, - scenario.mobility, - scenario.mobility_type, - scenario.radius_federation, - scenario.scheme_mobility, - scenario.round_frequency, - scenario.mobile_participants_percent, - json.dumps(scenario.additional_participants), - scenario.schema_additional_participants, - status, - role, - username, - ), - ) - else: - update_command = """ - UPDATE scenarios SET - start_time = ?, - end_time = ?, - title = ?, - description = ?, - deployment = ?, - federation = ?, - topology = ?, - nodes = ?, - nodes_graph = ?, - n_nodes = ?, - matrix = ?, - random_topology_probability = ?, - dataset = ?, - iid = ?, - partition_selection = ?, - partition_parameter = ?, - model = ?, - agg_algorithm = ?, - rounds = ?, - logginglevel = ?, - report_status_data_queue = ?, - accelerator = ?, - gpu_id = ?, - network_subnet = ?, - network_gateway = ?, - epochs = ?, - attack_params = ?, - reputation = ?, - random_geo = ?, - latitude = ?, - longitude = ?, - mobility = ?, - mobility_type = ?, - radius_federation = ?, - scheme_mobility = ?, - round_frequency = ?, - mobile_participants_percent = ?, - additional_participants = ?, - schema_additional_participants = ?, - status = ?, - role = ?, - username = ? 
- WHERE name = ?; - """ - c.execute( - update_command, - ( - start_time, - end_time, - scenario.scenario_title, - scenario.scenario_description, - scenario.deployment, - scenario.federation, - scenario.topology, - json.dumps(scenario.nodes), - json.dumps(scenario.nodes_graph), - scenario.n_nodes, - json.dumps(scenario.matrix), - scenario.random_topology_probability, - scenario.dataset, - scenario.iid, - scenario.partition_selection, - scenario.partition_parameter, - scenario.model, - scenario.agg_algorithm, - scenario.rounds, - scenario.logginglevel, - scenario.report_status_data_queue, - scenario.accelerator, - json.dumps(scenario.gpu_id), - scenario.network_subnet, - scenario.network_gateway, - scenario.epochs, - json.dumps(scenario.attack_params), - json.dumps(scenario.reputation), - scenario.random_geo, - scenario.latitude, - scenario.longitude, - scenario.mobility, - scenario.mobility_type, - scenario.radius_federation, - scenario.scheme_mobility, - scenario.round_frequency, - scenario.mobile_participants_percent, - json.dumps(scenario.additional_participants), - scenario.schema_additional_participants, - status, - role, - username, - name, - ), - ) - - conn.commit() - - -def scenario_set_all_status_to_finished(): - """ - Set the status of all currently running scenarios to "finished" and update their end time to the current datetime. - - Behavior: - - Finds all scenarios with status "running". - - Updates their status to "finished". - - Sets the end_time to the current timestamp. - - Commits the changes to the database. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - current_time = str(datetime.datetime.now()) - c.execute("UPDATE scenarios SET status = 'finished', end_time = ? WHERE status = 'running';", (current_time,)) - conn.commit() - - -def scenario_set_status_to_finished(scenario_name): - """ - Set the status of a specific scenario to "finished" and update its end time to the current datetime. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to update. - - Behavior: - - Updates the scenario's status to "finished". - - Sets the end_time to the current timestamp. - - Commits the update to the database. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - current_time = str(datetime.datetime.now()) - c.execute( - "UPDATE scenarios SET status = 'finished', end_time = ? WHERE name = ?;", (current_time, scenario_name) - ) - conn.commit() - - -def scenario_set_status_to_completed(scenario_name): - """ - Set the status of a specific scenario to "completed". - - Parameters: - scenario_name (str): The unique name identifier of the scenario to update. - - Behavior: - - Updates the scenario's status to "completed". - - Commits the change to the database. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("UPDATE scenarios SET status = 'completed' WHERE name = ?;", (scenario_name,)) - conn.commit() - - -def get_running_scenario(username=None, get_all=False): - """ - Retrieve running or completed scenarios from the database, optionally filtered by username. - - Parameters: - username (str, optional): The username to filter scenarios by. If None, no user filter is applied. - get_all (bool, optional): If True, returns all matching scenarios; otherwise returns only one. Defaults to False. 
- - Returns: - sqlite3.Row or list[sqlite3.Row]: A single scenario record or a list of scenario records matching the criteria. - - Behavior: - - Filters scenarios with status "running". - - Applies username filter if provided. - - Returns either one or all matching records depending on get_all. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - - if username: - command = """ - SELECT * FROM scenarios - WHERE (status = ?) AND username = ?; - """ - c.execute(command, ("running", username)) - - result = c.fetchone() - else: - command = "SELECT * FROM scenarios WHERE status = ?;" - c.execute(command, ("running",)) - if get_all: - result = c.fetchall() - else: - result = c.fetchone() - - return result - - -def get_completed_scenario(): - """ - Retrieve a single scenario with status "completed" from the database. - - Returns: - sqlite3.Row: A scenario record with status "completed", or None if no such scenario exists. - - Behavior: - - Fetches the first scenario found with status "completed". - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - command = "SELECT * FROM scenarios WHERE status = ?;" - c.execute(command, ("completed",)) - result = c.fetchone() - - return result - - -def get_scenario_by_name(scenario_name): - """ - Retrieve a scenario record by its unique name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario. - - Returns: - sqlite3.Row: The scenario record matching the given name, or None if not found. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM scenarios WHERE name = ?;", (scenario_name,)) - result = c.fetchone() - - return result - - -def get_user_by_scenario_name(scenario_name): - """ - Retrieve the username associated with a given scenario name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario. - - Returns: - str: The username linked to the specified scenario, or None if not found. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT username FROM scenarios WHERE name = ?;", (scenario_name,)) - result = c.fetchone() - - return result["username"] - - -def remove_scenario_by_name(scenario_name): - """ - Delete a scenario from the database by its unique name. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to be removed. - - Behavior: - - Removes the scenario record matching the given name. - - Commits the deletion to the database. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("DELETE FROM scenarios WHERE name = ?;", (scenario_name,)) - conn.commit() - - -def check_scenario_federation_completed(scenario_name): - """ - Check if all nodes in a given scenario have completed the required federation rounds. - - Parameters: - scenario_name (str): The unique name identifier of the scenario to check. - - Returns: - bool: True if all nodes have completed the total rounds specified for the scenario, False otherwise or if an error occurs. - - Behavior: - - Retrieves the total number of rounds defined for the scenario. - - Fetches the current round progress of all nodes in that scenario. - - Returns True only if every node has reached the total rounds. 
- - Handles database errors and missing scenario cases gracefully. - """ - try: - # Connect to the scenario database to get the total rounds for the scenario - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT rounds FROM scenarios WHERE name = ?;", (scenario_name,)) - scenario = c.fetchone() - - if not scenario: - raise ValueError(f"Scenario '{scenario_name}' not found.") - - total_rounds = scenario["rounds"] - - # Connect to the node database to check the rounds for each node - with sqlite3.connect(node_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT round FROM nodes WHERE scenario = ?;", (scenario_name,)) - nodes = c.fetchall() - - if len(nodes) == 0: - return False - - # Check if all nodes have completed the total rounds - total_rounds_str = str(total_rounds) - return all(str(node["round"]) == total_rounds_str for node in nodes) - - except sqlite3.Error as e: - print(f"Database error: {e}") - return False - except Exception as e: - print(f"An error occurred: {e}") - return False - - -def check_scenario_with_role(role, scenario_name): - """ - Verify if a scenario exists with a specific role and name. - - Parameters: - role (str): The role associated with the scenario (e.g., "admin", "user"). - scenario_name (str): The unique name identifier of the scenario. - - Returns: - bool: True if a scenario with the given role and name exists, False otherwise. - """ - with sqlite3.connect(scenario_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute( - "SELECT * FROM scenarios WHERE role = ? AND name = ?;", - ( - role, - scenario_name, - ), - ) - result = c.fetchone() - - return result is not None - - -def save_notes(scenario, notes): - """ - Save or update notes associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario. - notes (str): The textual notes to be saved for the scenario. - - Behavior: - - Inserts new notes if the scenario does not exist in the database. - - Updates existing notes if the scenario already has notes saved. - - Handles SQLite integrity and general database errors gracefully. - """ - try: - with sqlite3.connect(notes_db_file_location) as conn: - c = conn.cursor() - c.execute( - """ - INSERT INTO notes (scenario, scenario_notes) VALUES (?, ?) - ON CONFLICT(scenario) DO UPDATE SET scenario_notes = excluded.scenario_notes; - """, - (scenario, notes), - ) - conn.commit() - except sqlite3.IntegrityError as e: - print(f"SQLite integrity error: {e}") - except sqlite3.Error as e: - print(f"SQLite error: {e}") - - -def get_notes(scenario): - """ - Retrieve notes associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario. - - Returns: - sqlite3.Row or None: The notes record for the given scenario, or None if no notes exist. - """ - with sqlite3.connect(notes_db_file_location) as conn: - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM notes WHERE scenario = ?;", (scenario,)) - result = c.fetchone() - - return result - - -def remove_note(scenario): - """ - Delete the note associated with a specific scenario. - - Parameters: - scenario (str): The unique identifier of the scenario whose note should be removed. 
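The save_notes() helper above relies on SQLite's INSERT ... ON CONFLICT upsert. A minimal standalone sketch of the same pattern (it requires SQLite 3.24+ and a UNIQUE or PRIMARY KEY constraint on the conflict target, which the notes table is assumed to have):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE notes (scenario TEXT PRIMARY KEY, scenario_notes TEXT)")

def save_notes(scenario: str, notes: str) -> None:
    # 'excluded' refers to the row that failed to insert, i.e. the new values.
    conn.execute(
        """
        INSERT INTO notes (scenario, scenario_notes) VALUES (?, ?)
        ON CONFLICT(scenario) DO UPDATE SET scenario_notes = excluded.scenario_notes;
        """,
        (scenario, notes),
    )
    conn.commit()

save_notes("demo_scenario", "first draft")
save_notes("demo_scenario", "revised notes")  # second call updates the existing row
print(conn.execute("SELECT * FROM notes").fetchall())  # [('demo_scenario', 'revised notes')]
```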
- """ - with sqlite3.connect(notes_db_file_location) as conn: - c = conn.cursor() - c.execute("DELETE FROM notes WHERE scenario = ?;", (scenario,)) - conn.commit() - - -if __name__ == "__main__": - """ - Entry point for the script to print the list of users. - - When executed directly, this block calls the `list_users()` function - and prints its returned list of users. - """ - print(list_users()) diff --git a/app/databases/__init__.py b/nebula/controller/federation/__init__.py similarity index 100% rename from app/databases/__init__.py rename to nebula/controller/federation/__init__.py diff --git a/nebula/controller/federation/controllers/__init__.py b/nebula/controller/federation/controllers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nebula/controller/federation/controllers/docker_federation_controller.py b/nebula/controller/federation/controllers/docker_federation_controller.py new file mode 100644 index 000000000..cadda93d1 --- /dev/null +++ b/nebula/controller/federation/controllers/docker_federation_controller.py @@ -0,0 +1,621 @@ +import asyncio +import glob +import json +import os +import shutil +from nebula.utils import DockerUtils, APIUtils +import docker +from nebula.controller.federation.federation_controller import FederationController +from nebula.controller.federation.scenario_builder import ScenarioBuilder +from nebula.controller.federation.utils_requests import factory_requests +from nebula.controller.federation.utils_requests import RemoveScenarioRequest, NodeUpdateRequest, NodeDoneRequest +from typing import Dict +from nebula.config.config import Config +from nebula.core.utils.certificate import generate_ca_certificate +from nebula.core.utils.locker import Locker + +class NebulaFederationDocker(): + def __init__(self): + self.scenario_name = "" + self.participants_alive = 0 + self.round_per_participant = {} + self.additionals_participants = {} + self.additionals_deployables = [] + self.config = Config(entity="FederationController") + self.network_name = "" + self.base_network_name = "" + self.base = "" + self.last_index_deployed: int = 0 + self.federation_round: int = 0 + self.federation_deployment_lock = Locker("federation_deployment_lock", async_lock=True) + self.participants_alive_lock = Locker("participants_alive_lock", async_lock=True) + self.config_dir = "" + self.log_dir = "" + + async def get_additionals_to_be_deployed(self, config) -> list: + async with self.federation_deployment_lock: + if not self.additionals_participants: + return False + + participant_idx = int(config["device_args"]["idx"]) + participant_round = int(config["federation_args"]["round"]) + self.round_per_participant[participant_idx] = participant_round + self.federation_round = min(self.round_per_participant.values()) + + self.additionals_deployables = [ + idx + for idx, round in self.additionals_participants.items() + if self.federation_round >= round + ] + + additionals_deployables = self.additionals_deployables.copy() + for idx in additionals_deployables: + self.additionals_participants.pop(idx) + return additionals_deployables + + async def is_experiment_finish(self): + async with self.participants_alive_lock: + self.participants_alive -= 1 + if self.participants_alive <= 0: + return True + else: + return False + +class DockerFederationController(FederationController): + + def __init__(self, hub_url, logger): + super().__init__(hub_url, logger) + self.root_path = "" + self.host_platform = "" + self.config_dir = "" + self.log_dir = "" + self.cert_dir = "" + 
self.advanced_analytics = ""
+        self.url = ""
+        self._nebula_federations_pool: dict[str, NebulaFederationDocker] = {}
+        self._federations_dict_lock = Locker("federations_dict_lock", async_lock=True)
+
+    @property
+    def nfp(self):
+        """Nebula Federations Pool"""
+        return self._nebula_federations_pool
+
+    """ ###############################
+        #     ENDPOINT CALLBACKS      #
+        ###############################
+    """
+
+    async def run_scenario(self, federation_id: str, scenario_data: Dict, user: str):
+        # TODO: keep files in memory instead of re-reading them
+        federation = await self._add_nebula_federation_to_pool(federation_id, user)
+        scenario_info = {}
+        if federation:
+            scenario_builder = ScenarioBuilder(federation_id, user=user)
+            await self._initialize_scenario(scenario_builder, scenario_data, federation)
+            generate_ca_certificate(dir_path=self.cert_dir)
+            await self._load_configuration_and_start_nodes(scenario_builder, federation)
+            self._start_initial_nodes(scenario_builder, federation)
+            scenario_info = scenario_builder.get_scenario_info()
+            try:
+                nebula_federation = self.nfp[federation_id]
+                nebula_federation.scenario_name = scenario_builder.get_scenario_name()
+            except Exception:
+                self.logger.info(f"ERROR: federation ID: ({federation_id}) not found on pool..")
+                return None
+        else:
+            self.logger.info(f"ERROR: federation ID: ({federation_id}) already exists..")
+            # Look up the existing federation before stop_scenario() pops it from the pool,
+            # so its scenario name is still available for the removal request.
+            existing_federation = self.nfp.get(federation_id)
+            asyncio.create_task(self.stop_scenario(federation_id))
+            if existing_federation:
+                asyncio.create_task(self.remove_scenario(federation_id, RemoveScenarioRequest(experiment_type="docker", user=user, scenario_name=existing_federation.scenario_name)))
+        return scenario_info
+
+    async def stop_scenario(self, federation_id: str):
+        """
+        Remove all participant containers and the scenario network.
+        Reads the scenario.metadata file, removes every listed container and the network, then deletes the metadata file.
+        Also forcibly stops and removes any containers still attached to the network before removing it.
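NebulaFederationDocker.get_additionals_to_be_deployed() above gates "additional" participants on the federation round: each node reports its own round, the federation round is taken as the minimum across reporting nodes, and an additional participant becomes deployable once that minimum reaches its configured deployment_round. A simplified, lock-free sketch of the same bookkeeping (illustrative names, no Docker involved):

```python
def additionals_to_deploy(round_per_participant: dict[int, int],
                          pending_additionals: dict[int, int],
                          reporting_idx: int,
                          reported_round: int) -> list[int]:
    """Return indices of additional participants whose deployment round has been reached.

    round_per_participant: last round reported by each active participant.
    pending_additionals: participant index -> round at which it should join the federation.
    """
    round_per_participant[reporting_idx] = reported_round
    federation_round = min(round_per_participant.values())
    ready = [idx for idx, deploy_round in pending_additionals.items()
             if federation_round >= deploy_round]
    for idx in ready:
        pending_additionals.pop(idx)  # deploy each additional participant only once
    return ready

rounds = {0: 0, 1: 0}
pending = {2: 1, 3: 3}
print(additionals_to_deploy(rounds, pending, 0, 1))  # [] -> node 1 is still on round 0
print(additionals_to_deploy(rounds, pending, 1, 1))  # [2] -> minimum round is now 1
```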
+ """ + await asyncio.sleep(20) + federation = await self._remove_nebula_federation_from_pool(federation_id) + if not federation: + return False + + client = docker.from_env() + metadata_path = os.path.join(federation.config_dir, "scenario.metadata") + + if not os.path.exists(metadata_path): + self.logger.info(f"ERROR {metadata_path} - no 'scenario.metadata' found") + return False + + with open(metadata_path) as f: + meta = json.load(f) + # Remove containers listed in metadata + for name in meta.get("containers", []): + try: + container = client.containers.get(name) + container.remove(force=True) + self.logger.info(f"Removed scenario container {name}") + except Exception as e: + self.logger.info(f"Could not remove scenario container {name}: {e}") + # Remove network, but first forcibly remove any containers still attached + network_name = meta.get("network") + if network_name: + try: + network = client.networks.get(network_name) + attached_containers = network.attrs.get("Containers") or {} + for container_id in attached_containers: + try: + c = client.containers.get(container_id) + c.remove(force=True) + self.logger.info(f"Force-removed container {c.name} attached to {network_name}") + except Exception as e: + self.logger.info(f"Could not force-remove container {container_id}: {e}") + network.remove() + self.logger.info(f"Removed scenario network {network_name}") + except Exception as e: + self.logger.info(f"Could not remove scenario network {network_name}: {e}") + # Remove metadata file + try: + os.remove(metadata_path) + except Exception as e: + self.logger.info(f"Could not remove scenario.metadata: {e}") + return False + + return True #TODO care about cases + + async def update_nodes(self, federation_id: str, node_update_request: NodeUpdateRequest): + config = node_update_request.config + scenario_name = config["scenario_args"]["name"] + fed_id = config["scenario_args"]["federation_id"] + + try: + nebula_federation = self.nfp[fed_id] + self.logger.info(f"Update received from node on federation ID: ({fed_id})") + last_fed_round = nebula_federation.federation_round + additionals = await nebula_federation.get_additionals_to_be_deployed(config) # It modifies if neccesary the federation round + if additionals: + current_fed_round = nebula_federation.federation_round + adds_deployed = set() + if current_fed_round != last_fed_round: + self.logger.info(f"Federation Round updating for ID: ({fed_id}), current value: {current_fed_round}") + for index in additionals: + if index in adds_deployed: + continue + + for idx, node in enumerate(nebula_federation.config.participants): + if index == idx: + if index in additionals: + self.logger.info(f"Deploying additional participant: {index}") + deployed_successfully = self._start_node(nebula_federation.scenario_name, node, nebula_federation.network_name, nebula_federation.base_network_name, nebula_federation.base, nebula_federation.last_index_deployed, nebula_federation) + if deployed_successfully: + self.logger.info(f"Deployment successfully for additional participant: {index}") + nebula_federation.last_index_deployed += 1 + #additionals.remove(index) + adds_deployed.add(index) + payload = node_update_request.model_dump() + asyncio.create_task(self._send_to_hub("update", payload, federation_id=fed_id)) + return {"message": "Node updated successfully in Federation Controller"} + except Exception as e: + self.logger.info(f"ERROR: federation ID: ({fed_id}), {e}") + return {"message": "Node updated failed in Federation Controller"} + + async def 
node_done(self, federation_id: str, node_done_request: NodeDoneRequest): + nebula_federation = self.nfp[federation_id] + self.logger.info(f"Node-Done received from node on federation ID: ({federation_id})") + + if await nebula_federation.is_experiment_finish(): + payload = node_done_request.model_dump() + self.logger.info(f"All nodes have finished on federation ID: ({federation_id}), reporting to hub..") + await self._remove_nebula_federation_from_pool(federation_id) + asyncio.create_task(self._send_to_hub("finish", payload, federation_id=federation_id)) + + payload = node_done_request.model_dump() + asyncio.create_task(self._send_to_hub("done", payload, federation_id=federation_id)) + return {"message": "Nodes done received successfully"} + + async def remove_scenario(self, federation_id: str, remove_scenario_request: RemoveScenarioRequest): + await asyncio.sleep(40) + if(await self._check_active_federation(federation_id)): + self.logger.info(f"WARNING: Cannot remove files from active federation: ({federation_id})") + return False + + folder_name = remove_scenario_request.user+"_"+remove_scenario_request.scenario_name + scenario_config_path = os.path.join(self.config_dir, folder_name) + scenario_log_path = os.path.join(self.log_dir, folder_name) + + if not os.path.exists(scenario_config_path): + self.logger.info(f"ERROR {scenario_config_path} - no config folder found") + if not os.path.exists(scenario_log_path): + self.logger.info(f"ERROR {scenario_log_path} - no log folder found") + + try: + shutil.rmtree(scenario_config_path) + self.logger.info(f"Removed config folder {scenario_config_path}") + except Exception as e: + self.logger.info(f"Could not remove config folder {scenario_config_path}: {e}") + return False + + try: + shutil.rmtree(scenario_log_path) + self.logger.info(f"Removed log folder {scenario_log_path}") + except Exception as e: + self.logger.info(f"Could not remove log folder {scenario_log_path}: {e}") + return False + + return True + + """ ############################### + # FUNCTIONALITIES # + ############################### + """ + + async def _add_nebula_federation_to_pool(self, federation_id: str, user: str): + fed = None + async with self._federations_dict_lock: + if not federation_id in self.nfp: + fed = NebulaFederationDocker() + self.nfp[federation_id] = fed + self.logger.info(f"SUCCESS: new ID: ({federation_id}) added to the pool") + else: + self.logger.info(f"ERROR: trying to add ({federation_id}) to federations pool..") + return fed + + async def _remove_nebula_federation_from_pool(self, federation_id: str) -> NebulaFederationDocker | None: + async with self._federations_dict_lock: + if federation_id in self.nfp: + federation = self.nfp.pop(federation_id) + self.logger.info(f"SUCCESS: Federation ID: ({federation_id}) removed from pool") + return federation + else: + self.logger.info(f"ERROR: trying to remove ({federation_id}) from federations pool..") + return None + + async def _check_active_federation(self, federation_id: str) -> bool: + async with self._federations_dict_lock: + if federation_id in self.nfp: + return True + else: + return False + + async def _update_federation_on_pool(self, federation_id: str, user: str, nf: NebulaFederationDocker): + updated = False + async with self._federations_dict_lock: + if not federation_id in self.nfp: + self.nfp[federation_id] = nf + updated = True + self.logger.info(f"UPDATED: federation: ({federation_id}) successfully updated") + else: + self.logger.info(f"ERROR: trying to update ({federation_id}) on federations 
pool..") + return updated + + async def _send_to_hub(self, operation, payload, **kwargs): + try: + url_request = self._hub_url + factory_requests(operation, **kwargs) + await APIUtils.post(url_request, payload) + except Exception as e: + self.logger.info(f"Failed to send update to Hub: {e}") + + async def _initialize_scenario(self, sb: ScenarioBuilder, scenario_data, federation: NebulaFederationDocker): + # Initialize Scenario builder using scenario_data from user + self.logger.info("🔧 Initializing Scenario Builder using scenario data") + sb.set_scenario_data(scenario_data) + scenario_name = sb.get_scenario_name(user_to=True) + + self.root_path = os.environ.get("NEBULA_ROOT_HOST") + self.host_platform = os.environ.get("NEBULA_HOST_PLATFORM") + self.config_dir = os.environ.get("NEBULA_CONFIG_DIR") + self.log_dir = os.environ.get("NEBULA_LOGS_DIR") + federation.config_dir = os.path.join(os.environ.get("NEBULA_CONFIG_DIR"), scenario_name) + federation.log_dir = os.path.join(os.environ.get("NEBULA_LOGS_DIR"), scenario_name) + self.cert_dir = os.environ.get("NEBULA_CERTS_DIR") + self.advanced_analytics = os.environ.get("NEBULA_ADVANCED_ANALYTICS", "False") == "True" + self.env_tag = os.environ.get("NEBULA_ENV_TAG", "dev") + self.prefix_tag = os.environ.get("NEBULA_PREFIX_TAG", "dev") + self.user_tag = os.environ.get("NEBULA_USER_TAG", os.environ.get("USER", "unknown")) + + self.url = f"{os.environ.get('NEBULA_CONTROLLER_HOST')}:{os.environ.get('NEBULA_FEDERATION_CONTROLLER_PORT')}" + + # Create Scenario management dirs + os.makedirs(federation.config_dir, exist_ok=True) + os.makedirs(federation.log_dir, exist_ok=True) + os.makedirs(self.cert_dir, exist_ok=True) + + # Give permissions to the directories + os.chmod(federation.config_dir, 0o777) + os.chmod(federation.log_dir, 0o777) + os.chmod(self.cert_dir, 0o777) + + # Save the scenario configuration + scenario_file = os.path.join(federation.config_dir, "scenario.json") + with open(scenario_file, "w") as f: + json.dump(scenario_data, f, sort_keys=False, indent=2) + + os.chmod(scenario_file, 0o777) + + # Save management settings + settings = { + "scenario_name": scenario_name, + "root_path": self.root_path, + "config_dir": federation.config_dir, + "log_dir": federation.log_dir, + "cert_dir": self.cert_dir, + "env": None, + } + + settings_file = os.path.join(federation.config_dir, "settings.json") + with open(settings_file, "w") as f: + json.dump(settings, f, sort_keys=False, indent=2) + + os.chmod(settings_file, 0o777) + + # Attacks assigment and mobility + self.logger.info("🔧 Building general configuration") + sb.build_general_configuration() + self.logger.info("✅ Building general configuration done") + + # Create participant configs and .json + for index, (_, node) in enumerate(sb.get_federation_nodes().items()): + self.logger.info(f"Creating .json file for participant: {index}, Configuration: {node}") + node_config = node + try: + participant_file = os.path.join(federation.config_dir, f"participant_{node_config['id']}.json") + self.logger.info(f"Filename: {participant_file}") + os.makedirs(os.path.dirname(participant_file), exist_ok=True) + except Exception as e: + self.logger.info(f"ERROR while creating files: {e}") + + try: + participant_config = sb.build_scenario_config_for_node(index, node) + #self.logger.info(f"dictionary: {participant_config}") + except Exception as e: + self.logger.info(f"ERROR while building configuration for node: {e}") + + try: + with open(participant_file, "w") as f: + json.dump(participant_config, f, 
sort_keys=False, indent=2) + os.chmod(participant_file, 0o777) + except Exception as e: + self.logger.info(f"ERROR while dumping configuration into files: {e}") + + self.logger.info("✅ Initializing Scenario Builder done") + + async def _load_configuration_and_start_nodes(self, sb: ScenarioBuilder, federation: NebulaFederationDocker): + self.logger.info("🔧 Loading Scenario configuration...") + # Get participants configurations + participant_files = glob.glob(f"{federation.config_dir}/participant_*.json") + participant_files.sort() + if len(participant_files) == 0: + raise ValueError("No participant files found in config folder") + + federation.config.set_participants_config(participant_files) + n_nodes = len(participant_files) + #self.logger.info(f"Number of nodes: {n_nodes}") + + sb.create_topology_manager(federation.config) + + # Update participants configuration + is_start_node = False + config_participants = [] + + additional_participants = sb.get_additional_nodes() + additional_nodes = len(additional_participants) if additional_participants else 0 + #self.logger.info(f"######## nodes: {n_nodes} + additionals: {additional_nodes} ######") + + participant_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) + + # Initial participants + self.logger.info("🔧 Building preload configuration for initial nodes...") + for i in range(n_nodes): + try: + with open(f"{federation.config_dir}/participant_" + str(i) + ".json") as f: + participant_config = json.load(f) + except Exception as e: + self.logger.info(f"ERROR: open/load participant .json") + + self.logger.info(f"Building preload conf for participant {i}") + try: + sb.build_preload_initial_node_configuration(i, participant_config, federation.log_dir, federation.config_dir, self.cert_dir, self.advanced_analytics) + except Exception as e: + self.logger.info(f"ERROR: cannot build preload configuration") + + try: + with open(f"{federation.config_dir}/participant_" + str(i) + ".json", "w") as f: + json.dump(participant_config, f, sort_keys=False, indent=2) + except Exception as e: + self.logger.info(f"ERROR: cannot dump preload configuration into participant .json file") + + config_participants.append(( + participant_config["network_args"]["ip"], + participant_config["network_args"]["port"], + participant_config["device_args"]["role"], + )) + if participant_config["device_args"]["start"]: + if not is_start_node: + is_start_node = True + else: + raise ValueError("Only one node can be start node") + + self.logger.info("✅ Building preload configuration for initial nodes done") + + federation.config.set_participants_config(participant_files) + + # Add role to the topology (visualization purposes) + sb.visualize_topology(config_participants, path=f"{federation.config_dir}/topology.png", plot=False) + + # Additional participants + self.logger.info("🔧 Building preload configuration for additional nodes...") + additional_participants_files = [] + if additional_participants: + last_participant_file = participant_files[-1] + last_participant_index = len(participant_files) + + for i, _ in enumerate(additional_participants): + additional_participant_file = f"{federation.config_dir}/participant_{last_participant_index + i}.json" + shutil.copy(last_participant_file, additional_participant_file) + + with open(additional_participant_file) as f: + participant_config = json.load(f) + + self.logger.info(f"Configuration | additional nodes | participant: {n_nodes + i}") + sb.build_preload_additional_node_configuration(last_participant_index, i, 
participant_config) + + with open(additional_participant_file, "w") as f: + json.dump(participant_config, f, sort_keys=False, indent=2) + + additional_participants_files.append(additional_participant_file) + + if additional_participants_files: + federation.config.add_participants_config(additional_participants_files) + + if additional_participants: + n_nodes += len(additional_participants) + + self.logger.info("✅ Building preload configuration for additional nodes done") + self.logger.info("✅ Loading Scenario configuration done") + + # Build dataset + dataset = sb.configure_dataset(federation.config_dir) + self.logger.info(f"🔧 Splitting {sb.get_dataset_name()} dataset...") + dataset.initialize_dataset() + self.logger.info(f"✅ Splitting {sb.get_dataset_name()} dataset... Done") + + def _get_network_name(self, suffix: str) -> str: + """ + Generate a standardized network name using tags. + Args: + suffix (str): Suffix for the network (default: 'net-base'). + Returns: + str: The composed network name. + """ + return f"{self.env_tag}_{self.prefix_tag}_{self.user_tag}_{suffix}" + + def _get_participant_container_name(self, scenario_name, idx: int) -> str: + """ + Generate a standardized container name for a participant using tags. + Args: + idx (int): The participant index. + Returns: + str: The composed container name. + """ + return f"{self.env_tag}_{self.prefix_tag}_{self.user_tag}_{scenario_name}_participant{idx}" + + def _start_initial_nodes(self, sb: ScenarioBuilder, federation: NebulaFederationDocker): + self.logger.info("Starting nodes using Docker Compose...") + federation.network_name = self._get_network_name(f"{sb.get_scenario_name(user_to=True)}-net-scenario") + federation.base_network_name = self._get_network_name("net-base") + + # Create the Docker network + federation.base = DockerUtils.create_docker_network(federation.network_name) + + federation.config.participants.sort(key=lambda x: x["device_args"]["idx"]) + federation.last_index_deployed = 2 + for idx, node in enumerate(federation.config.participants): + + if node["deployment_args"]["additional"]: + federation.additionals_participants[idx] = int(node["deployment_args"]["deployment_round"]) + federation.participants_alive += 1 + self.logger.info(f"Participant {idx} is additional. 
Round of deployment: {int(node['deployment_args']['deployment_round'])}") + else: + # deploy initial nodes + self.logger.info(f"Deployment starting for participant {idx}") + federation.round_per_participant[idx] = 0 + deployed_successfully = self._start_node(sb.get_scenario_name(user_to=True), node, federation.network_name, federation.base_network_name, federation.base, federation.last_index_deployed, federation) + if deployed_successfully: + federation.last_index_deployed += 1 + federation.participants_alive += 1 + + def _start_node(self, scenario_name, node, network_name, base_network_name, base, i, federation: NebulaFederationDocker): + success = True + client = docker.from_env() + + federation.config.participants.sort(key=lambda x: x["device_args"]["idx"]) + container_ids = [] + container_names = [] # Track names for metadata + + image = "nebula-core" + name = self._get_participant_container_name(scenario_name, node["device_args"]["idx"]) + if node["device_args"]["accelerator"] == "gpu": + environment = { + "NVIDIA_DISABLE_REQUIRE": True, + "NEBULA_LOGS_DIR": "/nebula/app/logs/", + "NEBULA_CONFIG_DIR": "/nebula/app/config/", + } + host_config = client.api.create_host_config( + binds=[f"{self.root_path}:/nebula", "/var/run/docker.sock:/var/run/docker.sock"], + privileged=True, + device_requests=[docker.types.DeviceRequest(driver="nvidia", count=-1, capabilities=[["gpu"]])], + extra_hosts={"host.docker.internal": "host-gateway"}, + ) + else: + environment = {"NEBULA_LOGS_DIR": "/nebula/app/logs/", "NEBULA_CONFIG_DIR": "/nebula/app/config/"} + host_config = client.api.create_host_config( + binds=[f"{self.root_path}:/nebula", "/var/run/docker.sock:/var/run/docker.sock"], + privileged=True, + device_requests=[], + extra_hosts={"host.docker.internal": "host-gateway"}, + ) + volumes = ["/nebula", "/var/run/docker.sock"] + start_command = "sleep 10" if node["device_args"]["start"] else "sleep 0" + command = [ + "/bin/bash", + "-c", + f"{start_command} && ifconfig && echo '{base}.1 host.docker.internal' >> /etc/hosts && python /nebula/nebula/core/node.py /nebula/app/config/{scenario_name}/participant_{node['device_args']['idx']}.json", + ] + networking_config = client.api.create_networking_config({ + network_name: client.api.create_endpoint_config( + ipv4_address=f"{base}.{i}", + ), + base_network_name: client.api.create_endpoint_config(), + }) + node["tracking_args"]["log_dir"] = federation.log_dir + node["tracking_args"]["config_dir"] = federation.config_dir + node["scenario_args"]["controller"] = self.url + node["scenario_args"]["deployment"] = "docker" + node["security_args"]["certfile"] = f"/nebula/app/certs/participant_{node['device_args']['idx']}_cert.pem" + node["security_args"]["keyfile"] = f"/nebula/app/certs/participant_{node['device_args']['idx']}_key.pem" + node["security_args"]["cafile"] = "/nebula/app/certs/ca_cert.pem" + node = json.loads(json.dumps(node).replace("192.168.50.", f"{base}.")) # TODO change this + try: + existing = client.containers.get(name) + self.logger.info(f"Container {name} already exists. 
Deployment may fail or cause conflicts.") + success = False + except docker.errors.NotFound: + pass # No conflict, safe to proceed + # Write the config file in config directory + with open(f"{federation.config_dir}/participant_{node['device_args']['idx']}.json", "w") as f: + json.dump(node, f, indent=4) + try: + container_id = client.api.create_container( + image=image, + name=name, + detach=True, + volumes=volumes, + environment=environment, + command=command, + host_config=host_config, + networking_config=networking_config, + ) + except Exception as e: + success = False + self.logger.info(f"Creating container {name}: {e}") + try: + client.api.start(container_id) + container_ids.append(container_id) + container_names.append(name) + self.logger.info(f"Adding name: {name} for metadata") + except Exception as e: + success = False + self.logger.info(f"Starting participant {name} error: {e}") + + # Write scenario-level metadata for cleanup + scenario_metadata = {"containers": container_names, "network": network_name} + with open(os.path.join(federation.config_dir, "scenario.metadata"), "a") as f: + if i == 2: + json.dump(scenario_metadata, f, indent=2) + else: + with open(os.path.join(federation.config_dir, "scenario.metadata"), "r") as f: + metadata = json.load(f) + metadata["containers"].extend(container_names) + with open(os.path.join(federation.config_dir, "scenario.metadata"), "w") as f: + json.dump(metadata, f, indent=2) + + return success diff --git a/nebula/controller/federation/controllers/physicall_federation_controller.py b/nebula/controller/federation/controllers/physicall_federation_controller.py new file mode 100644 index 000000000..f60c2a8e1 --- /dev/null +++ b/nebula/controller/federation/controllers/physicall_federation_controller.py @@ -0,0 +1,4 @@ +from nebula.controller.federation.federation_controller import FederationController + +class PhysicalFederationController(FederationController): + pass \ No newline at end of file diff --git a/nebula/controller/federation/controllers/processes_federation_controller.py b/nebula/controller/federation/controllers/processes_federation_controller.py new file mode 100644 index 000000000..20eef4666 --- /dev/null +++ b/nebula/controller/federation/controllers/processes_federation_controller.py @@ -0,0 +1,562 @@ +import asyncio +import glob +import json +import os +import shutil +from nebula.utils import APIUtils +import docker +from nebula.controller.federation.federation_controller import FederationController +from nebula.controller.federation.scenario_builder import ScenarioBuilder +from nebula.controller.federation.utils_requests import factory_requests +from nebula.controller.federation.utils_requests import RemoveScenarioRequest, NodeUpdateRequest, NodeDoneRequest +from typing import Dict +from fastapi import Request +from nebula.config.config import Config +from nebula.core.utils.certificate import generate_ca_certificate +from nebula.core.utils.locker import Locker + +class NebulaFederationProcesses(): + def __init__(self): + self.scenario_name = "" + self.participants_alive = 0 + self.round_per_participant = {} + self.additionals_participants = {} + self.additionals_deployables = [] + self.config = Config(entity="FederationController") + self.network_name = "" + self.base_network_name = "" + self.base = "" + self.last_index_deployed: int = 0 + self.federation_round: int = 0 + self.federation_deployment_lock = Locker("federation_deployment_lock", async_lock=True) + self.participants_alive_lock = Locker("participants_alive_lock", 
async_lock=True) + self.config_dir = "" + self.log_dir = "" + + async def get_additionals_to_be_deployed(self, config) -> list: + async with self.federation_deployment_lock: + if not self.additionals_participants: + return False + + participant_idx = int(config["device_args"]["idx"]) + participant_round = int(config["federation_args"]["round"]) + self.round_per_participant[participant_idx] = participant_round + self.federation_round = min(self.round_per_participant.values()) + + self.additionals_deployables = [ + idx + for idx, round in self.additionals_participants.items() + if self.federation_round >= round + ] + + additionals_deployables = self.additionals_deployables.copy() + for idx in additionals_deployables: + self.additionals_participants.pop(idx) + return additionals_deployables + + async def is_experiment_finish(self): + async with self.participants_alive_lock: + self.participants_alive -= 1 + if self.participants_alive <= 0: + return True + else: + return False + +class ProcessesFederationController(FederationController): + def __init__(self, hub_url, logger): + super().__init__(hub_url, logger) + self.root_path = "" + self.host_platform = "" + self.config_dir = "" + self.log_dir = "" + self.cert_dir = "" + self.advanced_analytics = "" + self.url = "" + + self._nebula_federations_pool: dict[tuple[str,str], NebulaFederationProcesses] = {} + self._federations_dict_lock = Locker("federations_dict_lock", async_lock=True) + + @property + def nfp(self): + """Nebula Federations Pool""" + return self._nebula_federations_pool + + """ ############################### + # ENDPOINT CALLBACKS # + ############################### + """ + + async def run_scenario(self, federation_id: str, scenario_data: Dict, user: str): + #TODO maintain files on memory, not read them again + federation = await self._add_nebula_federation_to_pool(federation_id, user) + scenario_info = {} + if federation: + scenario_builder = ScenarioBuilder(federation_id, user=user) + await self._initialize_scenario(scenario_builder, scenario_data, federation) + generate_ca_certificate(dir_path=self.cert_dir) + await self._load_configuration_and_start_nodes(scenario_builder, federation) + self._start_initial_nodes(scenario_builder, federation) + scenario_info = scenario_builder.get_scenario_info() + try: + nebula_federation = self.nfp[federation_id] + nebula_federation.scenario_name = scenario_builder.get_scenario_name() + except Exception as e: + self.logger.info(f"ERROR: federation ID: ({federation_id}) not found on pool..") + return None + else: + self.logger.info(f"ERROR: federation ID: ({federation_id}) already exists..") + return scenario_info + + async def stop_scenario(self, federation_id: str = ""): + """ + Stop running participant nodes by removing the scenario command files. + + This method deletes the 'current_scenario_commands.sh' (or '.ps1' on Windows) + file associated with a scenario. Removing this file signals the nodes to stop + by terminating their processes. + + Args: + scenario_name (str, optional): The name of the scenario to stop. If None, + all scenarios' command files will be removed. + + Notes: + - If the environment variable NEBULA_CONFIG_DIR is not set, a default + configuration directory path is used. + - Supports both Linux/macOS ('.sh') and Windows ('.ps1') script files. + - Any errors during file removal are logged with the traceback. 
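As the docstring above notes, the processes controller drives node lifecycles through a generated commands file: starting a scenario writes current_scenario_commands.sh (or .ps1 on Windows) to launch one Python process per participant and record its PID, and stopping a scenario deletes that file. A minimal sketch of the Linux/macOS flavour of that script generation (the node entry point and file names are taken from this patch but simplified; treat paths as illustrative):

```python
import os

def write_launcher_script(config_dir: str, participant_config_files: list[str]) -> str:
    """Sketch: build a bash launcher that starts each node process and records its PID."""
    lines = [
        "#!/bin/bash",
        'PID_FILE="$(dirname "$0")/current_scenario_pids.txt"',
        "> $PID_FILE",
        "",
    ]
    for config_file in participant_config_files:
        lines.append(f'echo "Running node for {config_file}..."')
        lines.append(f"python nebula/core/node.py {config_file} &")  # entry point as used elsewhere in this patch
        lines.append("echo $! >> $PID_FILE")
        lines.append("")
    lines.append('echo "All nodes started. PIDs stored in $PID_FILE"')

    script_path = os.path.join(config_dir, "current_scenario_commands.sh")
    with open(script_path, "w") as f:
        f.write("\n".join(lines) + "\n")
    os.chmod(script_path, 0o755)  # executable, matching the controller's behaviour
    return script_path
```

Deleting the generated script is then how stop_scenario() above signals shutdown for a processes deployment.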
+ """ + federation = await self._remove_nebula_federation_from_pool(federation_id) + if not federation: + return False + + try: + if os.environ.get("NEBULA_HOST_PLATFORM") == "windows": + scenario_commands_file = os.path.join( + federation.config_dir, "current_scenario_commands.ps1" + ) + else: + scenario_commands_file = os.path.join( + federation.config_dir, "current_scenario_commands.sh" + ) + if os.path.exists(scenario_commands_file): + os.remove(scenario_commands_file) + self.logger.info(f"Scenario commands file removed: {scenario_commands_file}") + else: + self.logger.info(f"Scenario commands file not found: {scenario_commands_file}") + except Exception as e: + self.logger.exception(f"Error while removing current_scenario_commands file: {e}") + + async def update_nodes(self, federation_id: str, node_update_request: NodeUpdateRequest): + config = node_update_request.config + fed_id = config["scenario_args"]["federation_id"] + scenario_name = config["scenario_args"]["name"] + + try: + nebula_federation = self.nfp[fed_id] + self.logger.info(f"Update received from node on federation ID: ({fed_id})") + last_fed_round = nebula_federation.federation_round + additionals = await nebula_federation.get_additionals_to_be_deployed(config) # It modifies if neccesary the federation round + if additionals: + current_fed_round = nebula_federation.federation_round + adds_deployed = set() + if current_fed_round != last_fed_round: + self.logger.info(f"Federation Round updating for ID: ({fed_id}), current value: {current_fed_round}") + for index in additionals: + if index in adds_deployed: + continue + + for idx, node in enumerate(nebula_federation.config.participants): + if index == idx: + if index in additionals: + self.logger.info(f"Deploying additional participant: {index}") + #TODO additionals not working + self._start_node(node, nebula_federation.network_name, nebula_federation.base_network_name, nebula_federation.base, nebula_federation.last_index_deployed, nebula_federation, additional=True) + nebula_federation.last_index_deployed += 1 + additionals.remove(index) + adds_deployed.add(index) + payload = node_update_request.model_dump() + asyncio.create_task(self._send_to_hub("update", payload, federation_id=fed_id)) + return {"message": "Node updated successfully in Federation Controller"} + except Exception as e: + self.logger.info(f"ERROR: federation ID: ({fed_id}) not found on pool..") + return {"message": "Node updated failed in Federation Controller, ID not found.."} + + async def node_done(self, federation_id: str, node_done_request: NodeDoneRequest): + nebula_federation = self.nfp[federation_id] + self.logger.info(f"Node-Done received from node on federation ID: ({federation_id})") + + if await nebula_federation.is_experiment_finish(): + payload = node_done_request.model_dump() + self.logger.info(f"All nodes have finished on federation ID: ({federation_id}), reporting to hub..") + await self._remove_nebula_federation_from_pool(federation_id) + asyncio.create_task(self._send_to_hub("finish", payload, federation_id=federation_id)) + + payload = node_done_request.model_dump() + asyncio.create_task(self._send_to_hub("done", payload, federation_id=federation_id)) + return {"message": "Nodes done received successfully"} + + async def remove_scenario(self, federation_id: str, remove_scenario_request: RemoveScenarioRequest): + if(await self._check_active_federation(federation_id)): + self.logger.info(f"WARNING: Cannot remove files from active federation: ({federation_id})") + return False + + 
folder_name = remove_scenario_request.user+"_"+remove_scenario_request.scenario_name + scenario_config_path = os.path.join(self.config_dir, folder_name) + scenario_log_path = os.path.join(self.log_dir, folder_name) + + if not os.path.exists(scenario_config_path): + self.logger.info(f"ERROR {scenario_config_path} - no config folder found") + if not os.path.exists(scenario_log_path): + self.logger.info(f"ERROR {scenario_log_path} - no log folder found") + + try: + shutil.rmtree(scenario_config_path) + self.logger.info(f"Removed config folder {scenario_config_path}") + except Exception as e: + self.logger.info(f"Could not remove config folder {scenario_config_path}: {e}") + return False + + try: + shutil.rmtree(scenario_log_path) + self.logger.info(f"Removed log folder {scenario_log_path}") + except Exception as e: + self.logger.info(f"Could not remove log folder {scenario_log_path}: {e}") + return False + + return True + + """ ############################### + # FUNCTIONALITIES # + ############################### + """ + + async def _add_nebula_federation_to_pool(self, federation_id: str, user: str): + fed = None + async with self._federations_dict_lock: + if not federation_id in self.nfp: + fed = NebulaFederationProcesses() + self.nfp[federation_id] = fed + self.logger.info(f"SUCCESS: new ID: ({federation_id}) added to the pool") + else: + self.logger.info(f"ERROR: trying to add ({federation_id}) to federations pool..") + return fed + + async def _remove_nebula_federation_from_pool(self, federation_id: str) -> NebulaFederationProcesses | None: + async with self._federations_dict_lock: + if federation_id in self.nfp: + federation = self.nfp.pop(federation_id) + self.logger.info(f"SUCCESS: Federation ID: ({federation_id}) removed from pool") + return federation + else: + self.logger.info(f"ERROR: trying to remove ({federation_id}) from federations pool..") + return None + + async def _check_active_federation(self, federation_id: str) -> bool: + async with self._federations_dict_lock: + if federation_id in self.nfp: + return True + else: + return False + + async def _send_to_hub(self, operation, payload, **kwargs): + try: + url_request = self._hub_url + factory_requests(operation, **kwargs) + await APIUtils.post(url_request, payload) + except Exception as e: + self.logger.info(f"Failed to send update to Hub: {e}") + + async def _initialize_scenario(self, sb: ScenarioBuilder, scenario_data, federation: NebulaFederationProcesses): + # Initialize Scenario builder using scenario_data from user + self.logger.info("🔧 Initializing Scenario Builder using scenario data") + sb.set_scenario_data(scenario_data) + scenario_name = sb.get_scenario_name(user_to=True) + + self.root_path = os.environ.get("NEBULA_ROOT_HOST") + self.host_platform = os.environ.get("NEBULA_HOST_PLATFORM") + self.config_dir = os.environ.get("NEBULA_CONFIG_DIR") + self.log_dir = os.environ.get("NEBULA_LOGS_DIR") + federation.config_dir = os.path.join(os.environ.get("NEBULA_CONFIG_DIR"), scenario_name) + federation.log_dir = os.path.join(os.environ.get("NEBULA_LOGS_DIR"), scenario_name) + self.cert_dir = os.environ.get("NEBULA_CERTS_DIR") + self.advanced_analytics = os.environ.get("NEBULA_ADVANCED_ANALYTICS", "False") == "True" + self.env_tag = os.environ.get("NEBULA_ENV_TAG", "dev") + self.prefix_tag = os.environ.get("NEBULA_PREFIX_TAG", "dev") + self.user_tag = os.environ.get("NEBULA_USER_TAG", os.environ.get("USER", "unknown")) + + self.url = f"127.0.0.1:{os.environ.get('NEBULA_FEDERATION_CONTROLLER_PORT')}" + + # Create Scenario 
management dirs + os.makedirs(federation.config_dir, exist_ok=True) + os.makedirs(federation.log_dir, exist_ok=True) + os.makedirs(self.cert_dir, exist_ok=True) + + # Give permissions to the directories + os.chmod(federation.config_dir, 0o777) + os.chmod(federation.log_dir, 0o777) + os.chmod(self.cert_dir, 0o777) + + # Save the scenario configuration + scenario_file = os.path.join(federation.config_dir, "scenario.json") + with open(scenario_file, "w") as f: + json.dump(scenario_data, f, sort_keys=False, indent=2) + + os.chmod(scenario_file, 0o777) + + # Save management settings + settings = { + "scenario_name": scenario_name, + "root_path": self.root_path, + "config_dir": federation.config_dir, + "log_dir": federation.log_dir, + "cert_dir": self.cert_dir, + "env": None, + } + + settings_file = os.path.join(federation.config_dir, "settings.json") + with open(settings_file, "w") as f: + json.dump(settings, f, sort_keys=False, indent=2) + + os.chmod(settings_file, 0o777) + + # Attacks assigment and mobility + self.logger.info("🔧 Building general configuration") + sb.build_general_configuration() + self.logger.info("✅ Building general configuration done") + + # Create participant configs and .json + for index, (_, node) in enumerate(sb.get_federation_nodes().items()): + self.logger.info(f"Creating .json file for participant: {index}, Configuration: {node}") + node_config = node + try: + participant_file = os.path.join(federation.config_dir, f"participant_{node_config['id']}.json") + self.logger.info(f"Filename: {participant_file}") + os.makedirs(os.path.dirname(participant_file), exist_ok=True) + except Exception as e: + self.logger.info(f"ERROR while creating files: {e}") + + try: + participant_config = sb.build_scenario_config_for_node(index, node) + #self.logger.info(f"dictionary: {participant_config}") + except Exception as e: + self.logger.info(f"ERROR while building configuration for node: {e}") + + try: + with open(participant_file, "w") as f: + json.dump(participant_config, f, sort_keys=False, indent=2) + os.chmod(participant_file, 0o777) + except Exception as e: + self.logger.info(f"ERROR while dumping configuration into files: {e}") + + self.logger.info("✅ Initializing Scenario Builder done") + + async def _load_configuration_and_start_nodes(self, sb: ScenarioBuilder, federation: NebulaFederationProcesses): + self.logger.info("🔧 Loading Scenario configuration...") + # Get participants configurations + participant_files = glob.glob(f"{federation.config_dir}/participant_*.json") + participant_files.sort() + if len(participant_files) == 0: + raise ValueError("No participant files found in config folder") + + federation.config.set_participants_config(participant_files) + n_nodes = len(participant_files) + self.logger.info(f"Number of nodes: {n_nodes}") + + sb.create_topology_manager(federation.config) + + # Update participants configuration + is_start_node = False + config_participants = [] + + additional_participants = sb.get_additional_nodes() + additional_nodes = len(additional_participants) if additional_participants else 0 + + participant_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) + + # Initial participants + self.logger.info("🔧 Building preload configuration for initial nodes...") + for i in range(n_nodes): + try: + with open(f"{federation.config_dir}/participant_" + str(i) + ".json") as f: + participant_config = json.load(f) + except Exception as e: + self.logger.info(f"ERROR: open/load participant .json") + + self.logger.info(f"Building preload conf for 
participant {i}") + try: + sb.build_preload_initial_node_configuration(i, participant_config, federation.log_dir, federation.config_dir, self.cert_dir, self.advanced_analytics) + except Exception as e: + self.logger.info(f"ERROR: cannot build preload configuration") + + try: + with open(f"{federation.config_dir}/participant_" + str(i) + ".json", "w") as f: + json.dump(participant_config, f, sort_keys=False, indent=2) + except Exception as e: + self.logger.info(f"ERROR: cannot dump preload configuration into participant .json file") + + config_participants.append(( + participant_config["network_args"]["ip"], + participant_config["network_args"]["port"], + participant_config["device_args"]["role"], + )) + if participant_config["device_args"]["start"]: + if not is_start_node: + is_start_node = True + else: + raise ValueError("Only one node can be start node") + + self.logger.info("✅ Building preload configuration for initial nodes done") + + federation.config.set_participants_config(participant_files) + + # Add role to the topology (visualization purposes) + sb.visualize_topology(config_participants, path=f"{federation.config_dir}/topology.png", plot=False) + + # Additional participants + self.logger.info("🔧 Building preload configuration for additional nodes...") + additional_participants_files = [] + if additional_participants: + last_participant_file = participant_files[-1] + last_participant_index = len(participant_files) + + for i, _ in enumerate(additional_participants): + additional_participant_file = f"{federation.config_dir}/participant_{last_participant_index + i}.json" + shutil.copy(last_participant_file, additional_participant_file) + + with open(additional_participant_file) as f: + participant_config = json.load(f) + + self.logger.info(f"Configuration | additional nodes | participant: {n_nodes + i}") + sb.build_preload_additional_node_configuration(last_participant_index, i, participant_config) + + with open(additional_participant_file, "w") as f: + json.dump(participant_config, f, sort_keys=False, indent=2) + + additional_participants_files.append(additional_participant_file) + + if additional_participants_files: + federation.config.add_participants_config(additional_participants_files) + + if additional_participants: + n_nodes += len(additional_participants) + + self.logger.info("✅ Building preload configuration for additional nodes done") + self.logger.info("✅ Loading Scenario configuration done") + + # Build dataset + dataset = sb.configure_dataset(federation.config_dir) + self.logger.info(f"🔧 Splitting {sb.get_dataset_name()} dataset...") + dataset.initialize_dataset() + self.logger.info(f"✅ Splitting {sb.get_dataset_name()} dataset... Done") + + def _start_initial_nodes(self, sb: ScenarioBuilder, federation: NebulaFederationProcesses): + self.logger.info("Starting nodes as processes...") + self.logger.info(f"Number of participants: {len(federation.config.participants)}") + federation.config.participants.sort(key=lambda x: x["device_args"]["idx"]) + federation.last_index_deployed = 2 + + commands = "" + commands = self._build_initial_commands() + if not commands: + self.logger.info("ERROR: Cannot create commands file, abort..") + return + + for idx, node in enumerate(federation.config.participants): + if node["deployment_args"]["additional"]: + federation.additionals_participants[idx] = int(node["deployment_args"]["deployment_round"]) + federation.participants_alive += 1 + self.logger.info(f"Participant {idx} is additional. 
Round of deployment: {int(node['deployment_args']['deployment_round'])}") + else: + # deploy initial nodes + self.logger.info(f"Deployment starting for participant {idx}") + federation.round_per_participant[idx] = 0 + node_command = self._start_node(sb, node, federation.network_name, federation.base_network_name, federation.base, federation.last_index_deployed, federation) + commands += node_command + if node_command: + federation.last_index_deployed += 1 + federation.participants_alive += 1 + + if federation.config.participants and commands: + self._write_commands_on_file(commands, federation) + else: + self.logger.info("ERROR: No commands on a proccesses deployment..") + + def _start_node(self, sb: ScenarioBuilder, node, network_name, base_network_name, base, i, federation: NebulaFederationProcesses, additional=False): + self.processes_root_path = os.path.join(os.path.dirname(__file__), "..", "..") + node_idx = node['device_args']['idx'] + # Include additional config to the participants + node["tracking_args"]["log_dir"] = os.path.join(self.root_path, "app", "logs", sb.get_scenario_name(user_to=True)) + node["tracking_args"]["config_dir"] = os.path.join(self.root_path, "app", "config", sb.get_scenario_name(user_to=True)) + node["scenario_args"]["controller"] = self.url + node["scenario_args"]["deployment"] = sb.get_deployment() + node["security_args"]["certfile"] = os.path.join( + self.root_path, "app", "certs", f"participant_{node['device_args']['idx']}_cert.pem" + ) + node["security_args"]["keyfile"] = os.path.join( + self.root_path, "app", "certs", f"participant_{node['device_args']['idx']}_key.pem" + ) + node["security_args"]["cafile"] = os.path.join(self.root_path, "app", "certs", "ca_cert.pem") + # Write the config file in config directory + with open(f"{federation.config_dir}/participant_{node['device_args']['idx']}.json", "w") as f: + json.dump(node, f, indent=4) + + self.logger.info(f"Configuration file created successfully: {node_idx}") + commands = "" + try: + if self.host_platform == "windows": + if node["device_args"]["start"]: + commands += "Start-Sleep -Seconds 10\n" + else: + commands += "Start-Sleep -Seconds 2\n" + commands += f'Write-Host "Running node {node["device_args"]["idx"]}..."\n' + commands += f'$OUT_FILE = "{self.root_path}\\app\\logs\\{sb.get_scenario_name(user_to=True)}\\participant_{node["device_args"]["idx"]}.out"\n' + commands += f'$ERROR_FILE = "{self.root_path}\\app\\logs\\{sb.get_scenario_name(user_to=True)}\\participant_{node["device_args"]["idx"]}.err"\n' + # Use Start-Process for executing Python in background and capture PID + commands += f"""$process = Start-Process -FilePath "python" -ArgumentList "{self.root_path}\\nebula\\core\\node.py {self.root_path}\\app\\config\\{sb.get_scenario_name(user_to=True)}\\participant_{node["device_args"]["idx"]}.json" -PassThru -NoNewWindow -RedirectStandardOutput $OUT_FILE -RedirectStandardError $ERROR_FILE + Add-Content -Path $PID_FILE -Value $process.Id + """ + else: + if node["device_args"]["start"]: + commands += "sleep 10\n" + else: + commands += "sleep 2\n" + commands += f'echo "Running node {node["device_args"]["idx"]}..."\n' + commands += f"OUT_FILE={self.root_path}/app/logs/{sb.get_scenario_name(user_to=True)}/participant_{node['device_args']['idx']}.out\n" + commands += f"python {self.root_path}/nebula/core/node.py {self.root_path}/app/config/{sb.get_scenario_name(user_to=True)}/participant_{node['device_args']['idx']}.json &\n" + commands += "echo $! 
>> $PID_FILE\n\n" + except Exception as e: + raise Exception(f"Error starting nodes as processes: {e}") + + return commands + + def _build_initial_commands(self): + commands = "" + try: + if self.host_platform == "windows": + commands = """ + $ParentDir = Split-Path -Parent $PSScriptRoot + $PID_FILE = "$PSScriptRoot\\current_scenario_pids.txt" + New-Item -Path $PID_FILE -Force -ItemType File + + """ + else: + commands = '#!/bin/bash\n\nPID_FILE="$(dirname "$0")/current_scenario_pids.txt"\n\n> $PID_FILE\n\n' + except Exception as e: + raise Exception(f"Error starting nodes as processes: {e}") + return commands + + def _write_commands_on_file(self, commands: str, federation: NebulaFederationProcesses): + try: + if self.host_platform == "windows": + commands += 'Write-Host "All nodes started. PIDs stored in $PID_FILE"\n' + with open(f"{federation.config_dir}/current_scenario_commands.ps1", "w") as f: + #self.logger.info(f"Process commands: {commands}") + f.write(commands) + os.chmod(f"{federation.config_dir}/current_scenario_commands.ps1", 0o755) + else: + commands += 'echo "All nodes started. PIDs stored in $PID_FILE"\n' + with open(f"{federation.config_dir}/current_scenario_commands.sh", "w") as f: + #self.logger.info(f"Process commands: {commands}") + f.write(commands) + os.chmod(f"{federation.config_dir}/current_scenario_commands.sh", 0o755) + except Exception as e: + raise Exception(f"Error starting nodes as processes: {e}") diff --git a/nebula/controller/federation/factory_federation_controller.py b/nebula/controller/federation/factory_federation_controller.py new file mode 100644 index 000000000..1cdfc86b2 --- /dev/null +++ b/nebula/controller/federation/factory_federation_controller.py @@ -0,0 +1,15 @@ +from nebula.controller.federation.federation_controller import FederationController + +def federation_controller_factory(mode: str, wa_controller_url: str, logger) -> FederationController: + from nebula.controller.federation.controllers.docker_federation_controller import DockerFederationController + from nebula.controller.federation.controllers.processes_federation_controller import ProcessesFederationController + from nebula.controller.federation.controllers.physicall_federation_controller import PhysicalFederationController + + if mode == "docker": + return DockerFederationController(wa_controller_url, logger) + elif mode == "physical": + return PhysicalFederationController(wa_controller_url, logger) + elif mode == "process": + return ProcessesFederationController(wa_controller_url, logger) + else: + raise ValueError("Unknown federation mode") \ No newline at end of file diff --git a/nebula/controller/federation/federation_api.py b/nebula/controller/federation/federation_api.py new file mode 100644 index 000000000..f7513c828 --- /dev/null +++ b/nebula/controller/federation/federation_api.py @@ -0,0 +1,129 @@ +import argparse +import os +import logging +from fastapi import FastAPI, Body, Path, Request +from fastapi.concurrency import asynccontextmanager +from typing import Dict +from typing import Annotated +from functools import wraps +from fastapi import HTTPException +from nebula.utils import LoggerUtils +from nebula.controller.federation.federation_controller import FederationController +from nebula.controller.federation.factory_federation_controller import federation_controller_factory +from nebula.controller.federation.utils_requests import RemoveScenarioRequest, RunScenarioRequest, StopScenarioRequest, NodeUpdateRequest, NodeDoneRequest, Routes + +fed_controllers: Dict[str, 
FederationController] = {} + +@asynccontextmanager +async def lifespan(app: FastAPI): + log_path = os.environ.get("NEBULA_FEDERATION_CONTROLLER_LOG") + + # Configure and register the logger under the name "controller" + LoggerUtils.configure_logger(name="Federation-Controller", log_file=log_path) + + # Retrieve the logger by name + logger = logging.getLogger("Federation-Controller") + logger.info("Logger initialized for Federation Controller") + + # Create all controller types + hub_port = os.environ.get("NEBULA_CONTROLLER_PORT") + controller_host = os.environ.get("NEBULA_CONTROLLER_HOST") + hub_url = f"http://{controller_host}:{hub_port}" + + #["docker", "processes", "physical"] + for exp_type in ["docker", "process"]: + fed_controllers[exp_type] = federation_controller_factory(exp_type, hub_url, logger) + logger.info(f"{exp_type} Federation controller created.") + + yield + +app = FastAPI(lifespan=lifespan) + +@app.get("/") +async def read_root(): + """ + Root endpoint of the NEBULA Controller API. + + Returns: + dict: A welcome message indicating the API is accessible. + """ + logger = logging.getLogger("Federation-Controller") + logger.info("Test curl succesfull") + return {"message": "Welcome to the NEBULA Federation Controller API"} + +@app.post(Routes.RUN) +async def run_scenario(run_scenario_request: RunScenarioRequest): + global fed_controllers + experiment_type = run_scenario_request.scenario_data["deployment"] + logger = logging.getLogger("Federation-Controller") + logger.info(f"[API]: run experiment request for deployment type: {experiment_type}") + controller = fed_controllers.get(experiment_type, None) + if controller: + return await controller.run_scenario(run_scenario_request.federation_id, run_scenario_request.scenario_data, run_scenario_request.user) + else: + return {"message": "Experiment type not allowed"} + +@app.post(Routes.STOP) +async def stop_scenario( + federation_id: str, + stop_scenario_request: StopScenarioRequest +): + global fed_controllers + experiment_type = stop_scenario_request.experiment_type + controller = fed_controllers.get(experiment_type, None) + logger = logging.getLogger("Federation-Controller") + logger.info(f"[API]: stop experiment request for federation ID: {stop_scenario_request.federation_id}") + if controller: + return await controller.stop_scenario(federation_id) + else: + return {"message": "Experiment type not allowed"} + +@app.post(Routes.UPDATE) +async def update_nodes( + federation_id: str, + node_update_request: NodeUpdateRequest, +): + global fed_controllers + experiment_type = node_update_request.config["scenario_args"]["deployment"] + controller = fed_controllers.get(experiment_type, None) + if controller: + return await controller.update_nodes(federation_id, node_update_request) + else: + return {"message": "Experiment type not allowed on response for update message.."} + +@app.post(Routes.DONE) +async def node_done( + federation_id: str, + node_done_request: NodeDoneRequest, +): + global fed_controllers + experiment_type = node_done_request.deployment + controller = fed_controllers.get(experiment_type, None) + if controller: + return await controller.node_done(federation_id, node_done_request) + else: + return {"message": "Experiment type not allowed on responde for Node done message.."} + +@app.post(Routes.REMOVE) +async def scenario_remove( + federation_id: str, + remove_scenario_request: RemoveScenarioRequest, +): + global fed_controllers + experiment_type = remove_scenario_request.experiment_type + controller = 
fed_controllers.get(experiment_type, None) + if controller: + return await controller.remove_scenario(federation_id, remove_scenario_request) + else: + return {"message": "Experiment type not allowed in response to scenario remove message."} + +if __name__ == "__main__": + # Parse args from command line + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=5051, help="Port to run the Federation controller on.") + args = parser.parse_args() + + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=args.port) + + \ No newline at end of file diff --git a/nebula/controller/federation/federation_controller.py b/nebula/controller/federation/federation_controller.py new file mode 100644 index 000000000..290aae62d --- /dev/null +++ b/nebula/controller/federation/federation_controller.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from fastapi import Request +from typing import Dict +from nebula.controller.federation.scenario_builder import ScenarioBuilder +from nebula.controller.federation.utils_requests import NodeUpdateRequest, NodeDoneRequest, RemoveScenarioRequest +import logging + +class NebulaFederation(ABC): + pass + +class FederationController(ABC): + + def __init__(self, hub_url, logger): + self._logger: logging.Logger = logger + self._hub_url = hub_url + + @property + def logger(self): + return self._logger + + @abstractmethod + async def run_scenario(self, federation_id: str, scenario_data: Dict, user: str): + pass + + @abstractmethod + async def stop_scenario(self, federation_id: str): + pass + + @abstractmethod + async def update_nodes(self, federation_id: str, node_update_request: NodeUpdateRequest): + pass + + @abstractmethod + async def node_done(self, federation_id: str, node_done_request: NodeDoneRequest): + pass + + @abstractmethod + async def remove_scenario(self, federation_id: str, remove_scenario_request: RemoveScenarioRequest): + pass \ No newline at end of file diff --git a/nebula/controller/federation/scenario_builder.py b/nebula/controller/federation/scenario_builder.py new file mode 100644 index 000000000..ab4f2ffd9 --- /dev/null +++ b/nebula/controller/federation/scenario_builder.py @@ -0,0 +1,900 @@ +import logging +from datetime import datetime +import hashlib +import math +from collections import defaultdict +from nebula.addons.topologymanager import TopologyManager +from nebula.config.config import Config +from nebula.core.utils.certificate import generate_certificate +from nebula.core.datasets.nebuladataset import NebulaDataset, factory_nebuladataset, factory_dataset_setup + +class ScenarioBuilder(): + def __init__(self, federation_id, user): + self._scenario_data = None + self._config_setup = None + self.logger = logging.getLogger("Federation-Controller") + self._topology_manager: TopologyManager = None + self._scenario_name = "" + self._federation_id = federation_id + self._user = user + + @property + def sd(self): + """Scenario data dict""" + return self._scenario_data + + @property + def tm(self): + """Topology Manager""" + return self._topology_manager + + def get_scenario_name(self, user_to=False): + scenario_path = self._user+"_"+self._scenario_name if user_to else self._scenario_name + return scenario_path + + def set_scenario_data(self, scenario_data: dict): + self._scenario_data = scenario_data + federation_name = self.sd["federation"] + self._scenario_name = f"nebula_{federation_name}_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}" + + def set_config_setup(self, setup: dict): + self._config_setup = setup + + 
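# Illustrative (not exhaustive) shape of the scenario_data dict consumed by the getters below; + # the exact keys come from the frontend payload, e.g.: + # {"scenario_title": "...", "federation": "DFL", "deployment": "docker", "dataset": "MNIST", + #  "nodes": {...}, "additional_participants": [...], "rounds": 10, "model": "MLP", ...} + 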
def get_federation_nodes(self) -> dict: + return self.sd["nodes"] + + def get_additional_nodes(self): + return self.sd["additional_participants"] + + def get_dataset_name(self) -> str: + return self.sd["dataset"] + + def get_deployment(self) -> str: + return self.sd["deployment"] + + def get_scenario_info(self) -> dict: + return {"federation_id": self._federation_id, "start_time": datetime.now().strftime('%d/%m/%Y %H:%M:%S'), "alias": self.sd["scenario_title"] , "scenario_name": self._scenario_name} + + """ ############################### + # SCENARIO CONFIG NODE # + ############################### + """ + def build_general_configuration(self): + try: + self.sd["nodes"] = self._configure_nodes_attacks() + + if self.sd.get("mobility", None): + mobile_participants_percent = int(self.sd["mobile_participants_percent"]) + self.sd["nodes"] = self._mobility_assign(self.sd["nodes"], mobile_participants_percent) + else: + self.sd["nodes"] = self._mobility_assign(self.sd["nodes"], 0) + except Exception as e: + self.logger.info(f"ERROR: {e}") + + def _configure_nodes_attacks(self): + self.logger.info("Configurating node attacks...") + poisoned_node_percent = self.sd["attack_params"].get("poisoned_node_percent", 0) + poisoned_sample_percent = self.sd["attack_params"].get("poisoned_sample_percent", 0) + poisoned_noise_percent = self.sd["attack_params"].get("poisoned_noise_percent", 0) + + nodes = self.attack_node_assign( + self.sd.get("nodes"), + self.sd.get("federation"), + int(poisoned_node_percent), + int(poisoned_sample_percent), + int(poisoned_noise_percent), + self.sd.get("attack_params"), + ) + + self.logger.info("Configurating node attacks done") + return nodes + + def attack_node_assign( + self, + nodes, + federation, + poisoned_node_percent, + poisoned_sample_percent, + poisoned_noise_percent, + attack_params, + ): + """ + Assign and configure attack parameters to nodes within a federated learning network. + + This method: + - Validates input attack parameters and percentages. + - Determines which nodes will be marked as malicious based on the specified + poisoned node percentage and attack type. + - Assigns attack roles and parameters to selected nodes. + - Supports multiple attack types such as Label Flipping, Sample Poisoning, + Model Poisoning, GLL Neuron Inversion, Swapping Weights, Delayer, and Flooding. + - Ensures proper validation and setting of attack-specific parameters, including + targeting, noise types, delays, intervals, and attack rounds. + - Updates nodes' malicious status, reputation, and attack parameters accordingly. + + Args: + nodes (dict): Dictionary of nodes with their current attributes. + federation (str): Type of federated learning framework (e.g., "DFL"). + poisoned_node_percent (float): Percentage of nodes to be poisoned (0-100). + poisoned_sample_percent (float): Percentage of samples to be poisoned (0-100). + poisoned_noise_percent (float): Percentage of noise to apply in poisoning (0-100). + attack_params (dict): Dictionary containing attack type and associated parameters. + + Returns: + dict: Updated nodes dictionary with assigned malicious roles and attack parameters. + + Raises: + ValueError: If any input parameter is invalid or attack type is unrecognized. + """ + import random + + # Validate input parameters + def validate_percentage(value, name): + """ + Validate that a given value is a float percentage between 0 and 100. + + Args: + value: The value to validate, expected to be convertible to float. 
+ name (str): Name of the parameter, used for error messages. + + Returns: + float: The validated percentage value. + + Raises: + ValueError: If the value is not a float or not within the range [0, 100]. + """ + try: + value = float(value) + if not 0 <= value <= 100: + raise ValueError(f"{name} must be between 0 and 100") + return value + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid {name}: {e!s}") + + def validate_positive_int(value, name): + """ + Validate that a given value is a positive integer (including zero). + + Args: + value: The value to validate, expected to be convertible to int. + name (str): Name of the parameter, used for error messages. + + Returns: + int: The validated positive integer value. + + Raises: + ValueError: If the value is not an integer or is negative. + """ + try: + value = int(value) + if value < 0: + raise ValueError(f"{name} must be positive") + return value + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid {name}: {e!s}") + + # Validate attack type + valid_attacks = { + "No Attack", + "Label Flipping", + "Sample Poisoning", + "Model Poisoning", + "GLL Neuron Inversion", + "Swapping Weights", + "Delayer", + "Flooding", + } + + # Get attack type from attack_params + if attack_params and "attacks" in attack_params: + attack = attack_params["attacks"] + + # Handle attack parameter which can be either a string or None + if attack is None: + attack = "No Attack" + elif not isinstance(attack, str): + raise ValueError(f"Invalid attack type: {attack}. Expected string or None.") + + if attack not in valid_attacks: + raise ValueError(f"Invalid attack type: {attack}. Must be one of {valid_attacks}") + + # Get attack parameters from attack_params + poisoned_node_percent = attack_params.get("poisoned_node_percent", poisoned_node_percent) + poisoned_sample_percent = attack_params.get("poisoned_sample_percent", poisoned_sample_percent) + poisoned_noise_percent = attack_params.get("poisoned_noise_percent", poisoned_noise_percent) + + # Validate percentage parameters + poisoned_node_percent = validate_percentage(poisoned_node_percent, "poisoned_node_percent") + poisoned_sample_percent = validate_percentage(poisoned_sample_percent, "poisoned_sample_percent") + poisoned_noise_percent = validate_percentage(poisoned_noise_percent, "poisoned_noise_percent") + + nodes_index = [] + # Get the nodes index + if federation == "DFL": + nodes_index = list(nodes.keys()) + else: + for node in nodes: + if nodes[node]["role"] != "server": + nodes_index.append(node) + + self.logger.info(f"Nodes index: {nodes_index}") + self.logger.info(f"Attack type: {attack}") + self.logger.info(f"Poisoned node percent: {poisoned_node_percent}") + + mal_nodes_defined = any(nodes[node]["malicious"] for node in nodes) + self.logger.info(f"Malicious nodes already defined: {mal_nodes_defined}") + + attacked_nodes = [] + + if not mal_nodes_defined and attack != "No Attack": + n_nodes = len(nodes_index) + # Number of attacked nodes, round up + num_attacked = int(math.ceil(poisoned_node_percent / 100 * n_nodes)) + if num_attacked > n_nodes: + num_attacked = n_nodes + + # Get the index of attacked nodes + attacked_nodes = random.sample(nodes_index, num_attacked) + self.logger.info(f"Number of nodes to attack: {num_attacked}") + self.logger.info(f"Attacked nodes: {attacked_nodes}") + + # Assign the role of each node + for node in nodes: + node_att = "No Attack" + malicious = False + #node_reputation = self.reputation.copy() if self.reputation else None + + if node in 
attacked_nodes or nodes[node]["malicious"]: + malicious = True + node_reputation = None + node_att = attack + self.logger.info(f"Node {node} marked as malicious with attack {attack}") + + # Initialize attack parameters with defaults + node_attack_params = attack_params.copy() if attack_params else {} + + # Set attack-specific parameters + if attack == "Label Flipping": + node_attack_params["poisoned_node_percent"] = poisoned_node_percent + node_attack_params["poisoned_sample_percent"] = poisoned_sample_percent + node_attack_params["targeted"] = attack_params.get("targeted", False) + if node_attack_params["targeted"]: + node_attack_params["target_label"] = validate_positive_int( + attack_params.get("target_label", 4), "target_label" + ) + node_attack_params["target_changed_label"] = validate_positive_int( + attack_params.get("target_changed_label", 7), "target_changed_label" + ) + + elif attack == "Sample Poisoning": + node_attack_params["poisoned_node_percent"] = poisoned_node_percent + node_attack_params["poisoned_sample_percent"] = poisoned_sample_percent + node_attack_params["poisoned_noise_percent"] = poisoned_noise_percent + node_attack_params["noise_type"] = attack_params.get("noise_type", "Gaussian") + node_attack_params["targeted"] = attack_params.get("targeted", False) + if node_attack_params["targeted"]: + node_attack_params["target_label"] = validate_positive_int( + attack_params.get("target_label", 4), "target_label" + ) + + elif attack == "Model Poisoning": + node_attack_params["poisoned_node_percent"] = poisoned_node_percent + node_attack_params["poisoned_noise_percent"] = poisoned_noise_percent + node_attack_params["noise_type"] = attack_params.get("noise_type", "Gaussian") + + elif attack == "GLL Neuron Inversion": + node_attack_params["poisoned_node_percent"] = poisoned_node_percent + + elif attack == "Swapping Weights": + node_attack_params["poisoned_node_percent"] = poisoned_node_percent + node_attack_params["layer_idx"] = validate_positive_int( + attack_params.get("layer_idx", 0), "layer_idx" + ) + + elif attack == "Delayer": + node_attack_params["poisoned_node_percent"] = poisoned_node_percent + node_attack_params["delay"] = validate_positive_int(attack_params.get("delay", 10), "delay") + node_attack_params["target_percentage"] = validate_percentage( + attack_params.get("target_percentage", 100), "target_percentage" + ) + node_attack_params["selection_interval"] = validate_positive_int( + attack_params.get("selection_interval", 1), "selection_interval" + ) + + elif attack == "Flooding": + node_attack_params["poisoned_node_percent"] = poisoned_node_percent + node_attack_params["flooding_factor"] = validate_positive_int( + attack_params.get("flooding_factor", 100), "flooding_factor" + ) + node_attack_params["target_percentage"] = validate_percentage( + attack_params.get("target_percentage", 100), "target_percentage" + ) + node_attack_params["selection_interval"] = validate_positive_int( + attack_params.get("selection_interval", 1), "selection_interval" + ) + + # Add common attack parameters + node_attack_params["round_start_attack"] = validate_positive_int( + attack_params.get("round_start_attack", 1), "round_start_attack" + ) + node_attack_params["round_stop_attack"] = validate_positive_int( + attack_params.get("round_stop_attack", 10), "round_stop_attack" + ) + node_attack_params["attack_interval"] = validate_positive_int( + attack_params.get("attack_interval", 1), "attack_interval" + ) + + # Validate round parameters + if node_attack_params["round_start_attack"] >= 
node_attack_params["round_stop_attack"]: + raise ValueError("round_start_attack must be less than round_stop_attack") + + node_attack_params["attacks"] = node_att + nodes[node]["malicious"] = True + nodes[node]["attack_params"] = node_attack_params + nodes[node]["fake_behavior"] = nodes[node]["role"] + nodes[node]["role"] = "malicious" + # else: + # nodes[node]["attack_params"] = {"attacks": "No Attack"} + + if nodes[node].get("attack_params", None): + self.logger.info( + f"Node {node} final configuration - malicious: {nodes[node]['malicious']}, attack: {nodes[node]['attack_params']['attacks']}" + ) + else: + self.logger.info( + f"Node {node} final configuration - malicious: {nodes[node]['malicious']}" + ) + + return nodes + + def _mobility_assign(self, nodes, mobile_participants_percent): + """ + Assign mobility status to a subset of nodes based on a specified percentage. + + This method: + - Calculates the number of mobile nodes by applying the given percentage. + - Randomly selects nodes to be marked as mobile. + - Updates each node's "mobility" attribute to True or False accordingly. + + Args: + nodes (dict): Dictionary of nodes with their current attributes. + mobile_participants_percent (float): Percentage of nodes to be assigned mobility (0-100). + + Returns: + dict: Updated nodes dictionary with mobility status assigned. + """ + import random + + # Number of mobile nodes, round down + num_mobile = math.floor(mobile_participants_percent / 100 * len(nodes)) + if num_mobile > len(nodes): + num_mobile = len(nodes) + + # Get the index of mobile nodes + mobile_nodes = random.sample(list(nodes.keys()), num_mobile) + + # Assign the role of each node + for node in nodes: + node_mob = False + if node in mobile_nodes: + node_mob = True + nodes[node]["mobility"] = node_mob + return nodes + + + """ ############################### + # SCENARIO CONFIG NODE # + ############################### + """ + + def build_scenario_config_for_node(self, index, node) -> dict: + self.logger.info(f"Start building the scenario configuration for participant {index}") + + def recursive_defaultdict(): + return defaultdict(recursive_defaultdict) + + def dictify(d): + if isinstance(d, defaultdict): + return {k: dictify(v) for k, v in d.items()} + return d + + participant_config = recursive_defaultdict() + + addons_config = defaultdict() + #participant_config["addons"] = dict() + + # General configuration + participant_config["scenario_args"]["name"] = self._scenario_name + participant_config["scenario_args"]["start_time"] = datetime.now().strftime("%d/%m/%Y %H:%M:%S") + participant_config["scenario_args"]["federation_id"] = self._federation_id + participant_config["deployment_args"]["additional"] = False + + node_config = node #self.sd["nodes"][index] + participant_config["network_args"]["ip"] = node_config["ip"] + if self.sd["deployment"] == "physical": + participant_config["network_args"]["port"] = 8000 + else: + participant_config["network_args"]["port"] = int(node_config["port"]) + + participant_config["network_args"]["simulation"] = self.sd["network_simulation"] + participant_config["device_args"]["idx"] = node_config["id"] + participant_config["device_args"]["start"] = node_config["start"] + participant_config["device_args"]["role"] = node_config["role"] + participant_config["device_args"]["proxy"] = node_config["proxy"] + participant_config["device_args"]["malicious"] = node_config["malicious"] + participant_config["scenario_args"]["rounds"] = int(self.sd["rounds"]) + 
participant_config["scenario_args"]["random_seed"] = 42 + participant_config["federation_args"]["round"] = 0 + participant_config["data_args"]["dataset"] = self.sd["dataset"] + participant_config["data_args"]["iid"] = self.sd["iid"] + participant_config["data_args"]["num_workers"] = 0 + participant_config["data_args"]["partition_selection"] = self.sd["partition_selection"] + participant_config["data_args"]["partition_parameter"] = self.sd["partition_parameter"] + participant_config["model_args"]["model"] = self.sd["model"] + participant_config["training_args"]["epochs"] = int(self.sd["epochs"]) + participant_config["training_args"]["trainer"] = "lightning" + participant_config["device_args"]["accelerator"] = self.sd["accelerator"] + participant_config["device_args"]["gpu_id"] = self.sd["gpu_id"] + participant_config["device_args"]["logging"] = self.sd["logginglevel"] + participant_config["aggregator_args"]["algorithm"] = self.sd["agg_algorithm"] + participant_config["aggregator_args"]["aggregation_timeout"] = 60 + + participant_config["message_args"]= self._configure_message_args() + participant_config["reporter_args"]= self._configure_reporter_args() + participant_config["forwarder_args"]= self._configure_forwarder_args() + participant_config["propagator_args"]= self._configure_propagator_args() + participant_config["misc_args"]= self._configure_misc_args() + + # Addons configuration + + # Trustworthiness + try: + if self.sd.get("with_trustworthiness", None): + addons_config["trustworthiness"] = self._configure_trustworthiness() + except Exception as e: + self.logger.info(f"ERROR: Cannot build trustworthiness configuration - {e}") + + # Reputation + try: + if self.sd.get("reputation", None) and self.sd["reputation"]["enabled"] and not node_config["role"] == "malicious": + addons_config["reputation"] = self._configure_reputation() + except Exception as e: + self.logger.info(f"ERROR: Cannot build reputation configuration - {e}") + + # Network simulation + try: + network_args: dict = (self.sd.get("network_args"), None) + if network_args and isinstance(network_args, dict) and network_args.get("enabled", None): + addons_config["network_simulation"] = self._configure_network_simulation() + except Exception as e: + self.logger.info(f"ERROR: Cannot build network simulation configuration - {e}") + + # Attacks + try: + if node_config["role"] == "malicious": + addons_config["adversarial_args"] = self._configure_malicious_role(node_config) + except Exception as e: + self.logger.info(f"ERROR: Cannot build role configuration - {e}") + + # Mobility + try: + if self.sd.get("mobility", None): + addons_config["mobility"] = self._configure_mobility_args() + except Exception as e: + self.logger.info(f"ERROR: Cannot build mobility configuration - {e}") + + # Situational awareness module + try: + if self._situational_awareness_needed(): + addons_config["situational_awareness"] = self._configure_situational_awareness(index) + except Exception as e: + self.logger.info(f"ERROR: Cannot build situational awareness configuration - {e}") + + # Addon addition to the configuration + participant_config["addons"] = addons_config + + try: + config = dictify(participant_config) + except Exception as e: + self.logger.info(f"ERROR: Translating into dictionary - {e}") + + return config + + def _configure_message_args(self): + return { + "max_local_messages": 10000, + "compression": "zlib" + } + + def _configure_reporter_args(self): + return { + "grace_time_reporter": 10, + "report_frequency": 5, + "report_status_data_queue": 
True + } + + def _configure_forwarder_args(self): + return { + "forwarder_interval": 1, + "forward_messages_interval": 0, + "number_forwarded_messages": 100 + } + + def _configure_propagator_args(self): + return { + "propagate_interval": 3, + "propagate_model_interval": 0, + "propagation_early_stop": 3, + "history_size": 20 + } + + def _configure_misc_args(self): + return { + "grace_time_connection": 10, + "grace_time_start_federation": 10 + } + + def _configure_mobility_args(self): + return { + "enabled": True, + "mobility_type": self.sd["mobility_type"], + "topology_type": self.sd["topology"], + "radius_federation": self.sd["radius_federation"], + "scheme_mobility": self.sd["scheme_mobility"], + "round_frequency": self.sd["round_frequency"], + "grace_time_mobility": 60, + "change_geo_interval": 5 + } + + def _configure_malicious_role(self, node_config: dict): + return { + "fake_behavior": node_config["fake_behavior"], + "attack_params": node_config["attack_params"] + } + + def _configure_trustworthiness(self) -> dict: + trust_config = { + "robustness_pillar": self.sd["robustness_pillar"], + "resilience_to_attacks": self.sd["resilience_to_attacks"], + "algorithm_robustness": self.sd["algorithm_robustness"], + "client_reliability": self.sd["client_reliability"], + "privacy_pillar": self.sd["privacy_pillar"], + "technique": self.sd["technique"], + "uncertainty": self.sd["uncertainty"], + "indistinguishability": self.sd["indistinguishability"], + "fairness_pillar": self.sd["fairness_pillar"], + "selection_fairness": self.sd["selection_fairness"], + "performance_fairness": self.sd["performance_fairness"], + "class_distribution": self.sd["class_distribution"], + "explainability_pillar": self.sd["explainability_pillar"], + "interpretability": self.sd["interpretability"], + "post_hoc_methods": self.sd["post_hoc_methods"], + "accountability_pillar": self.sd["accountability_pillar"], + "factsheet_completeness": self.sd["factsheet_completeness"], + "architectural_soundness_pillar": self.sd["architectural_soundness_pillar"], + "client_management": self.sd["client_management"], + "optimization": self.sd["optimization"], + "sustainability_pillar": self.sd["sustainability_pillar"], + "energy_source": self.sd["energy_source"], + "hardware_efficiency": self.sd["hardware_efficiency"], + "federation_complexity": self.sd["federation_complexity"], + "scenario": self.sd, + } + return trust_config + + def _configure_reputation(self) -> dict: + rep = self.sd.get("reputation") + rep["adaptive_args"] = True + return rep + + def _configure_network_simulation(self) -> dict: + network_parameters = {} + network_generation = dict(self.sd["network_args"]).pop("network_type") + enabled = dict(self.sd["network_args"]).pop("enabled") + type = dict(self.sd["network_args"]).pop("type") + addrs = "" + + for node in self.sd["nodes"]: + ip = self.sd["nodes"][node]["ip"] + port = self.sd["nodes"][node]["port"] + addrs = addrs + " " + f"{ip}:{port}" + + network_configuration = { + "interface": "eth0", + "verbose": False, + "preset": network_generation, + "federation": addrs + } + + network_parameters = { + "enabled": enabled, + "type": type, + "network_config": network_configuration + } + + return network_parameters + + def _situational_awareness_needed(self): + enabled = False + arrivals_dep = self.sd.get("arrivals_departures_args", None) + if arrivals_dep: + enabled = arrivals_dep["enabled"] + with_sa = self.sd.get("with_sa", None) + additionals = self.sd.get("additional_participants", None) + mob = self.sd.get("mobility", 
None) + + return with_sa or enabled or arrivals_dep or additionals or mob + + def _configure_situational_awareness(self, index) -> dict: + try: + scheduled_isolation = self._configure_arrivals_departures(index) + except Exception as e: + self.logger.info(f"ERROR: cannot configure arrival departures section - {e}") + + snp = self.sd.get("sar_neighbor_policy", None) + topology_management = snp if (snp != "") else self.sd["topology"] + + situational_awareness_config = { + "strict_topology": self.sd["strict_topology"], + "sa_discovery": { + "candidate_selector": topology_management, + "model_handler": self.sd["sad_model_handler"], + "verbose": True, + }, + "sa_reasoner": { + "arbitration_policy": self.sd["sar_arbitration_policy"], + "verbose": True, + "sar_components": { + "sa_network": True, + "sa_training": self.sd["sar_training"] + }, + "sa_network": { + "neighbor_policy": topology_management, + "scheduled_isolation" : scheduled_isolation, + "verbose": True + }, + "sa_training": { + "training_policy": self.sd["sar_training_policy"], + "verbose": True + }, + }, + } + return situational_awareness_config + + def _configure_arrivals_departures(self, index) -> dict: + arrival_dep_section = self.sd.get("arrivals_departures_args", None) + if not arrival_dep_section or (arrival_dep_section and not self.sd["arrivals_departures_args"]["enabled"]): + return {"enabled": False} + + config = {"enabled": True} + departures: list = self.sd["arrivals_departures_args"]["departures"] + index_departure_config: dict = departures[index] + if index_departure_config["round_start"] != "": + config["round_start"] = index_departure_config["round_start"] + config["duration"] = index_departure_config["duration"] if index_departure_config["duration"] != "" else None + else: + config = {"enabled": False} + + return config + + """ ############################### + # PRELOAD CONFIG # + ############################### + """ + + def build_preload_initial_node_configuration(self, index, participant_config: dict, log_dir, config_dir, cert_dir, advanced_analytics): + try: + participant_config["scenario_args"]["federation"] = self.sd["federation"] + n_nodes = len(self.sd["nodes"].keys()) + n_additionals = len(self.sd["additional_participants"]) + participant_config["scenario_args"]["n_nodes"] = n_nodes + n_additionals + + participant_config["network_args"]["neighbors"] = self.tm.get_neighbors_string(index) + + participant_config["device_args"]["idx"] = index + participant_config["device_args"]["uid"] = hashlib.sha1( + ( + str(participant_config["network_args"]["ip"]) + + str(participant_config["network_args"]["port"]) + + str(participant_config["scenario_args"]["name"]) + ).encode() + ).hexdigest() + except Exception as e: + self.logger.info(f"ERROR while setting up general stuff") + + try: + if participant_config.get("addons", None) and participant_config["addons"].get("mobility", None): + if participant_config["addons"]["mobility"].get("random_geo", None): + ( + participant_config["addons"]["mobility"]["latitude"], + participant_config["addons"]["mobility"]["longitude"], + ) = TopologyManager.get_coordinates(random_geo=True) + else: + participant_config["addons"]["mobility"]["latitude"] = self.sd["latitude"] + participant_config["addons"]["mobility"]["longitude"] = self.sd["longitude"] + except Exception as e: + self.logger.info(f"ERROR while setting up mobility parameters - {e}") + + try: + participant_config["tracking_args"] = {} + participant_config["security_args"] = {} + + # If not, use the given coordinates in the 
frontend + participant_config["tracking_args"]["local_tracking"] = "default" + participant_config["tracking_args"]["log_dir"] = log_dir + participant_config["tracking_args"]["config_dir"] = config_dir + # Generate node certificate + keyfile_path, certificate_path = generate_certificate( + dir_path=cert_dir, + node_id=f"participant_{index}", + ip=participant_config["network_args"]["ip"], + ) + participant_config["security_args"]["certfile"] = certificate_path + participant_config["security_args"]["keyfile"] = keyfile_path + except Exception as e: + self.logger.info(f"ERROR while setting up tracking args and certificates") + + def build_preload_additional_node_configuration(self, last_participant_index, index, participant_config): + n_nodes = len(self.sd["nodes"].keys()) + n_additionals = len(self.sd["additional_participants"]) + last_ip = participant_config["network_args"]["ip"] + participant_config["scenario_args"]["n_nodes"] = n_nodes + n_additionals # self.n_nodes + i + 1 + participant_config["device_args"]["idx"] = last_participant_index + index + participant_config["network_args"]["neighbors"] = "" + participant_config["network_args"]["ip"] = ( + participant_config["network_args"]["ip"].rsplit(".", 1)[0] + + "." + + str(int(participant_config["network_args"]["ip"].rsplit(".", 1)[1]) + index + 1) + ) + participant_config["device_args"]["uid"] = hashlib.sha1( + ( + str(participant_config["network_args"]["ip"]) + + str(participant_config["network_args"]["port"]) + + str(self._scenario_name) + ).encode() + ).hexdigest() + participant_config["deployment_args"]["additional"] = True + + deployment_round = self.sd["additional_participants"][index]["time_start"] + participant_config["deployment_args"]["deployment_round"] = deployment_round + + # used for late creation nodes + + """ ############################### + # TOPOLOGY MANAGER # + ############################### + """ + + def create_topology_manager(self, config: Config): + try: + self._topology_manager = ( + self._create_topology(config, matrix=self.sd["matrix"]) if self.sd["matrix"] else self._create_topology(config) + ) + except Exception as e: + self.logger.info(f"ERROR: cannot create topology manager - {e}") + + def _create_topology(self, config: Config, matrix=None): + """ + Create and return a network topology manager based on the scenario's topology settings or a given adjacency matrix. + + Supports multiple topology types: + - Random: Generates an Erdős-Rényi random graph with specified connection probability. + - Matrix: Uses a provided adjacency matrix to define the topology. + - Fully: Creates a fully connected network. + - Ring: Creates a ring-structured network with partial connectivity. + - Star: Creates a centralized star topology (only for CFL federation). + + The method assigns IP and port information to nodes and returns the configured TopologyManager instance. + + Args: + matrix (optional): Adjacency matrix to define custom topology. If provided, overrides scenario topology. + + Raises: + ValueError: If an unknown topology type is specified in the scenario. + + Returns: + TopologyManager: Configured topology manager with nodes assigned. 
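+            Example (illustrative; `adj` is a hypothetical n_nodes x n_nodes adjacency matrix): + topologymanager = self._create_topology(config) + topologymanager = self._create_topology(config, matrix=adj)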
+ """ + import numpy as np + + n_nodes = len(self.sd["nodes"].keys()) + if self.sd["topology"] == "Random": + # Create network topology using topology manager (random) + probability = float(self.sd["random_topology_probability"]) + logging.info( + f"Creating random network topology using erdos_renyi_graph: nodes={n_nodes}, probability={probability}" + ) + topologymanager = TopologyManager( + scenario_name=self._scenario_name, + n_nodes=n_nodes, + b_symmetric=True, + undirected_neighbor_num=3, + ) + topologymanager.generate_random_topology(probability) + elif matrix is not None: + if n_nodes > 2: + topologymanager = TopologyManager( + topology=np.array(matrix), + scenario_name=self._scenario_name, + n_nodes=n_nodes, + b_symmetric=True, + undirected_neighbor_num=n_nodes - 1, + ) + else: + topologymanager = TopologyManager( + topology=np.array(matrix), + scenario_name=self._scenario_name, + n_nodes=n_nodes, + b_symmetric=True, + undirected_neighbor_num=2, + ) + elif self.sd["topology"] == "Fully": + # Create a fully connected network + topologymanager = TopologyManager( + scenario_name=self._scenario_name, + n_nodes=n_nodes, + b_symmetric=True, + undirected_neighbor_num=n_nodes - 1, + ) + topologymanager.generate_topology() + elif self.sd["topology"] == "Ring": + # Create a partially connected network (ring-structured network) + topologymanager = TopologyManager(scenario_name=self._scenario_name, n_nodes=n_nodes, b_symmetric=True) + topologymanager.generate_ring_topology(increase_convergence=True) + elif self.sd["topology"] == "Star" and self.sd["federation"] == "CFL": + # Create a centralized network + topologymanager = TopologyManager(scenario_name=self._scenario_name, n_nodes=n_nodes, b_symmetric=True) + topologymanager.generate_server_topology() + else: + top = self.sd["topology"] + raise ValueError(f"Unknown topology type: {top}") + + # Assign nodes to topology + nodes_ip_port = [] + config.participants.sort(key=lambda x: int(x["device_args"]["idx"])) + for i, node in enumerate(config.participants): + nodes_ip_port.append(( + node["network_args"]["ip"], + node["network_args"]["port"], + "undefined", + )) + + topologymanager.add_nodes(nodes_ip_port) + return topologymanager + + def visualize_topology(self, config_participants, path, plot): + try: + self.tm.update_nodes(config_participants) + self.tm.draw_graph(path=path, plot=plot) + except Exception as e: + self.logger.info(f"ERROR: cannot visualize topology - {e}") + + """ ############################### + # DATASET CONFIGURATION # + ############################### + """ + + def configure_dataset(self, config_dir) -> NebulaDataset: + try: + dataset_name = self.get_dataset_name() + dataset = factory_nebuladataset( + dataset_name, + **self._configure_dataset_config(dataset_name, config_dir) + ) + except Exception as e: + self.logger.info(f"ERROR: cannot configure dataset - {e}") + return dataset + + def _configure_dataset_config(self, dataset_name, config_dir): + num_classes = factory_dataset_setup(dataset_name) + n_nodes = len(self.sd["nodes"].keys()) + n_nodes += len(self.sd["additional_participants"]) + return { + "num_classes": num_classes, + "partitions_number": n_nodes, + "iid": self.sd["iid"], + "partition": self.sd["partition_selection"], + "partition_parameter": self.sd["partition_parameter"], + "seed": 42, + "config_dir": config_dir, + } diff --git a/nebula/controller/federation/utils_requests.py b/nebula/controller/federation/utils_requests.py new file mode 100644 index 000000000..e51d7f22a --- /dev/null +++ 
b/nebula/controller/federation/utils_requests.py @@ -0,0 +1,48 @@ +from pydantic import BaseModel +from typing import Dict, Any + +class RunScenarioRequest(BaseModel): + scenario_data: Dict[str, Any] + user: str + federation_id: str + +class StopScenarioRequest(BaseModel): + experiment_type: str + federation_id: str + +class NodeUpdateRequest(BaseModel): + config: Dict[str, Any] = {} + +class NodeDoneRequest(BaseModel): + idx: int + deployment: str + name: str + federation_id: str + +class RemoveScenarioRequest(BaseModel): + experiment_type: str + user: str + scenario_name: str + +class Routes: + INIT = "/init" + RUN = "/scenarios/run" + STOP = "/scenarios/{federation_id}/stop" + UPDATE = "/nodes/{federation_id}/update" + DONE = "/nodes/{federation_id}/done" + FINISH = "/scenarios/{federation_id}/finish" + REMOVE = "/scenarios/{federation_id}/remove" + + @classmethod + def format(cls, route: str, **kwargs) -> str: + return getattr(cls, route).format(**kwargs) + +def factory_requests_path(resource: str, **kwargs) -> str: + try: + return Routes.format(resource.upper(), **kwargs) + except AttributeError: + raise ValueError(f"Resource not found: {resource}") + except KeyError as e: + raise ValueError(f"Missing parameter for route '{resource}': {e}") + + \ No newline at end of file diff --git a/nebula/controller/controller.py b/nebula/controller/hub.py similarity index 59% rename from nebula/controller/controller.py rename to nebula/controller/hub.py index a00d142d1..7602c1d27 100755 --- a/nebula/controller/controller.py +++ b/nebula/controller/hub.py @@ -1,12 +1,13 @@ import argparse import asyncio -import datetime +from datetime import datetime import importlib import ipaddress import json import logging import os import re +import copy from typing import Annotated import aiohttp @@ -14,10 +15,13 @@ import uvicorn from fastapi import Body, FastAPI, Request, status, HTTPException, Path, File, UploadFile from fastapi.concurrency import asynccontextmanager - -from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished from nebula.controller.http_helpers import remote_get, remote_post_form -from nebula.utils import DockerUtils +from nebula.utils import APIUtils, DockerUtils +import nebula.controller.federation.utils_requests as federation_requests +import nebula.controller.utils_requests as controller_requests + +# URL for the database API +DATABASE_API_URL = os.environ.get("NEBULA_DATABASE_API_URL", "http://nebula-database:5051") # Setup controller logger @@ -66,9 +70,6 @@ def format(self, record): return super().format(record) -os.environ["NEBULA_CONTROLLER_NAME"] = os.environ.get("USER") - - def configure_logger(controller_log): """ Configures the logging system for the controller. @@ -102,27 +103,28 @@ def configure_logger(controller_log): handler.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")) logger.addHandler(handler) - @asynccontextmanager async def lifespan(app: FastAPI): - databases_dir: str = os.environ.get("NEBULA_DATABASES_DIR") + """ + Application lifespan context manager. + - Configures logging on startup. 
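+        - Yields control to the application; no explicit teardown is performed on shutdown. 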
+ """ + # Code to run on startup controller_log: str = os.environ.get("NEBULA_CONTROLLER_LOG") - - from nebula.controller.database import initialize_databases - - await initialize_databases(databases_dir) - configure_logger(controller_log) yield + # Code to run on shutdown + pass + # Initialize FastAPI app outside the Controller class app = FastAPI(lifespan=lifespan) # Define endpoints outside the Controller class -@app.get("/") +@app.get(controller_requests.Routes.INIT) async def read_root(): """ Root endpoint of the NEBULA Controller API. @@ -133,7 +135,7 @@ async def read_root(): return {"message": "Welcome to the NEBULA Controller API"} -@app.get("/status") +@app.get(controller_requests.Routes.STATUS) async def get_status(): """ Check the status of the NEBULA Controller API. @@ -144,7 +146,7 @@ async def get_status(): return {"status": "NEBULA Controller API is running"} -@app.get("/resources") +@app.get(controller_requests.Routes.RESOURCES) async def get_resources(): """ Get system resource usage including RAM and GPU memory usage. @@ -186,7 +188,7 @@ async def get_resources(): } -@app.get("/least_memory_gpu") +@app.get(controller_requests.Routes.LEAST_MEMORY_GPU) async def get_least_memory_gpu(): """ Identify the GPU with the highest memory usage above a threshold (50%). @@ -228,7 +230,7 @@ async def get_least_memory_gpu(): } -@app.get("/available_gpus/") +@app.get(controller_requests.Routes.AVAILABLE_GPUS) async def get_available_gpu(): """ Get the list of GPUs with memory usage below 5%. @@ -264,33 +266,31 @@ async def get_available_gpu(): def validate_physical_fields(data: dict): if data.get("deployment") != "physical": - return - + return + ips = data.get("physical_ips") if not ips: raise HTTPException( status_code=400, detail="physical deployment requires 'physical_ips'" ) - + if len(ips) != data.get("n_nodes"): raise HTTPException( status_code=400, detail="'physical_ips' must have the same length as 'n_nodes'" ) - + try: for ip in ips: - ipaddress.ip_address(ip) + ipaddress.ip_address(ip) print(ip) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) -@app.post("/scenarios/run") -async def run_scenario( - scenario_data: dict = Body(..., embed=True), role: str = Body(..., embed=True), user: str = Body(..., embed=True) -): +@app.post(controller_requests.Routes.RUN) +async def run_scenario(run_scenario_request: controller_requests.RunScenarioRequest): """ Launches a new scenario based on the provided configuration. @@ -302,113 +302,141 @@ async def run_scenario( Returns: str: The name of the scenario that was started. 
""" + import hashlib + def generate_id(value: str) -> str: + return hashlib.sha256(value.encode("utf-8")).hexdigest() - import subprocess + response = None - from nebula.controller.scenarios import ScenarioManagement + try: + fed_controller_port = os.environ.get("NEBULA_FEDERATION_CONTROLLER_PORT") + fed_controller_host = os.environ.get("NEBULA_CONTROLLER_HOST") + url_run_scenario = f"http://{fed_controller_host}:{fed_controller_port}" + federation_requests.factory_requests_path("run") + #init_fed_req = InitFederationRequest(experiment_type="docker") + federation_id = generate_id(f"nebula_{run_scenario_request.user}_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}") + run_scenario_req = federation_requests.RunScenarioRequest(scenario_data=run_scenario_request.scenario_data, federation_id=federation_id, user=run_scenario_request.user) #TODO ID per experiment + #await APIUtils.post(url_init_fed_controller, init_fed_req.model_dump()) + response = await APIUtils.post(url_run_scenario, run_scenario_req.model_dump()) + except Exception as e: + logging.info(e) - validate_physical_fields(scenario_data) + # Unpack request data (role is intentionally ignored for now) + scenario_data = run_scenario_request.scenario_data + user = run_scenario_request.user + + # validate_physical_fields(scenario_data) # Manager for the actual scenario - scenarioManagement = ScenarioManagement(scenario_data, user) - - await update_scenario( - scenario_name=scenarioManagement.scenario_name, - start_time=scenarioManagement.start_date_scenario, - end_time="", - scenario=scenario_data, - status="running", - role=role, - username=user, - ) + #scenarioManagement = ScenarioManagement(scenario_data, user) - # Run the actual scenario - try: - if scenarioManagement.scenario.mobility: - additional_participants = scenario_data["additional_participants"] - schema_additional_participants = scenario_data["schema_additional_participants"] - await scenarioManagement.load_configurations_and_start_nodes( - additional_participants, schema_additional_participants + if response: + try: + payload = controller_requests.ScenarioUpdateRequest( + alias=response["alias"], + scenario_name=response["scenario_name"], + start_time=response["start_time"], + end_time="", + scenario=run_scenario_request.scenario_data, + status="running", + username=run_scenario_request.user, + ).model_dump() + path = controller_requests.factory_requests_path( + "update", federation_id=federation_id ) - else: - await scenarioManagement.load_configurations_and_start_nodes() - except subprocess.CalledProcessError as e: - logging.exception(f"Error docker-compose up: {e}") - return - - return scenarioManagement.scenario_name - - -@app.post("/scenarios/stop") + await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) + return response["federation_id"] + except Exception as e: + logging.info(e) + else: + raise HTTPException(500, detail={"failed running scenario"}) + +@app.post(controller_requests.Routes.STOP) # TODO redo method async def stop_scenario( - scenario_name: str = Body(..., embed=True), - username: str = Body(..., embed=True), + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Federation identifier", + ), + ], all: bool = Body(False, embed=True), ): """ Stops the execution of a federated learning scenario and performs cleanup operations. - + This endpoint: - Stops all participant containers associated with the specified scenario. 
- Removes Docker containers and network resources tied to the scenario and user. - Sets the scenario's status to "finished" in the database. - Optionally finalizes all active scenarios if the 'all' flag is set. - + Args: - scenario_name (str): Name of the scenario to stop. - username (str): User who initiated the stop operation. + federation_id (str): Identifier of the scenario to stop. all (bool): Whether to stop all running scenarios instead of just one (default: False). - + Raises: HTTPException: Returns a 500 status code if any step fails. - + Note: This function does not currently trigger statistics generation. """ - from nebula.controller.scenarios import ScenarioManagement - - # ScenarioManagement.stop_participants(scenario_name) - DockerUtils.remove_containers_by_prefix(f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_{username}-participant") - DockerUtils.remove_docker_network( - f"{(os.environ.get('NEBULA_CONTROLLER_NAME'))}_{str(username).lower()}-nebula-net-scenario" + fed_controller_port = os.environ.get("NEBULA_FEDERATION_CONTROLLER_PORT") + fed_controller_host = os.environ.get("NEBULA_CONTROLLER_HOST") + url_stop_scenario = ( + f"http://{fed_controller_host}:{fed_controller_port}" + + federation_requests.factory_requests_path("stop") + ) + stop_scenario_req = federation_requests.StopScenarioRequest( + experiment_type="docker", federation_id=federation_id ) try: - if all: - scenario_set_all_status_to_finished() - else: - scenario_set_status_to_finished(scenario_name) + path = controller_requests.factory_requests_path( + "stop", federation_id=federation_id + ) + payload = controller_requests.ScenarioStopRequest(all=all).model_dump() + await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) + await APIUtils.post(url_stop_scenario, stop_scenario_req.model_dump()) except Exception as e: - logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") - raise HTTPException(status_code=500, detail="Internal server error") + logging.info(f"ERROR: sending stop scenario to federation Controller: {e}") -@app.post("/scenarios/remove") +@app.post(controller_requests.Routes.REMOVE) async def remove_scenario( - scenario_name: str = Body(..., embed=True), + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Federation identifier", + ), + ], + request: controller_requests.ScenarioRemoveRequest, ): """ - Removes a scenario from the database by its name. - - Args: - scenario_name (str): Name of the scenario to remove. - - Returns: - dict: A message indicating successful removal. + Removes a scenario from the database by its federation identifier. 
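+    Args: + federation_id (str): Federation identifier of the scenario to remove. + request (ScenarioRemoveRequest): Payload carrying the scenario name used to clean up its files. 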
""" - from nebula.controller.database import remove_scenario_by_name from nebula.controller.scenarios import ScenarioManagement try: - remove_scenario_by_name(scenario_name) - ScenarioManagement.remove_files_by_scenario(scenario_name) + path = controller_requests.factory_requests_path( + "remove", federation_id=federation_id + ) + await APIUtils.post(f"{DATABASE_API_URL}{path}") + ScenarioManagement.remove_files_by_scenario(request.scenario_name) + except Exception as e: - logging.exception(f"Error removing scenario {scenario_name}: {e}") + logging.exception( + f"Error removing scenario {request.scenario_name} ({federation_id}): {e}" + ) raise HTTPException(status_code=500, detail="Internal server error") - return {"message": f"Scenario {scenario_name} removed successfully"} + return {"message": f"Scenario {request.scenario_name} removed successfully"} -@app.get("/scenarios/{user}/{role}") +@app.get(controller_requests.Routes.GET_SCENARIOS_BY_USER) async def get_scenarios( user: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid username")], role: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid role")], @@ -423,89 +451,88 @@ async def get_scenarios( Returns: dict: A list of scenarios and the currently running scenario. """ - from nebula.controller.database import get_all_scenarios_and_check_completed, get_running_scenario - try: - scenarios = get_all_scenarios_and_check_completed(username=user, role=role) - if role == "admin": - scenario_running = get_running_scenario() - else: - scenario_running = get_running_scenario(username=user) + path = controller_requests.factory_requests_path("get_scenarios_by_user", user=user, role=role) + return await APIUtils.get(f"{DATABASE_API_URL}{path}") except Exception as e: logging.exception(f"Error obtaining scenarios: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return {"scenarios": scenarios, "scenario_running": scenario_running} - -@app.post("/scenarios/update") +@app.post(controller_requests.Routes.UPDATE) async def update_scenario( - scenario_name: str = Body(..., embed=True), - start_time: str = Body(..., embed=True), - end_time: str = Body(..., embed=True), - scenario: dict = Body(..., embed=True), - status: str = Body(..., embed=True), - role: str = Body(..., embed=True), - username: str = Body(..., embed=True), + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Federation identifier", + ), + ], + request: controller_requests.ScenarioUpdateRequest, ): """ - Updates the status and metadata of a scenario. + Updates the status and metadata of a scenario identified by its federation ID. Args: - scenario_name (str): Name of the scenario. - start_time (str): Start time of the scenario. - end_time (str): End time of the scenario. - scenario (dict): Scenario configuration. - status (str): New status of the scenario (e.g., "running", "finished"). - role (str): Role associated with the scenario. - username (str): User performing the update. + federation_id (str): Identifier of the scenario to update. + request (ScenarioUpdateRequest): Payload containing alias, scenario name, timing, configuration, status and username. Returns: dict: A message confirming the update. 
""" - from nebula.controller.database import scenario_update_record - from nebula.controller.scenarios import Scenario - try: - scenario = Scenario.from_dict(scenario) - scenario_update_record(scenario_name, start_time, end_time, scenario, status, role, username) + payload = request.model_dump() + path = controller_requests.factory_requests_path( + "update", federation_id=federation_id + ) + return await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) except Exception as e: - logging.exception(f"Error updating scenario {scenario_name}: {e}") + logging.exception( + f"Error updating scenario {request.scenario_name} ({federation_id}): {e}" + ) raise HTTPException(status_code=500, detail="Internal server error") - return {"message": f"Scenario {scenario_name} updated successfully"} - -@app.post("/scenarios/set_status_to_finished") +@app.post(controller_requests.Routes.FINISH) async def set_scenario_status_to_finished( - scenario_name: str = Body(..., embed=True), all: bool = Body(False, embed=True) + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Federation identifier", + ), + ], + all: bool = Body(False, embed=True), ): """ Sets the status of a scenario (or all scenarios) to 'finished'. Args: - scenario_name (str): Name of the scenario to mark as finished. + federation_id (str): Identifier of the scenario to mark as finished. all (bool): If True, sets all scenarios to finished. Returns: dict: A message confirming the operation. """ - from nebula.controller.database import scenario_set_all_status_to_finished, scenario_set_status_to_finished - try: - if all: - scenario_set_all_status_to_finished() - else: - scenario_set_status_to_finished(scenario_name) + payload = controller_requests.ScenarioFinishRequest(all=all).model_dump() + path = controller_requests.factory_requests_path( + "finish", federation_id=federation_id + ) + return await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) except Exception as e: - logging.exception(f"Error setting scenario {scenario_name} to finished: {e}") + logging.exception( + f"Error setting scenario {federation_id} to finished: {e}" + ) raise HTTPException(status_code=500, detail="Internal server error") - return {"message": f"Scenario {scenario_name} status set to finished successfully"} - -@app.get("/scenarios/running") -async def get_running_scenario(get_all: bool = False): +@app.get(controller_requests.Routes.RUNNING) +async def get_running_scenario_endpoint(get_all: bool = False): """ Retrieves the currently running scenario(s). @@ -515,20 +542,20 @@ async def get_running_scenario(get_all: bool = False): Returns: dict or list: Running scenario(s) information. 
""" - from nebula.controller.database import get_running_scenario - try: - return get_running_scenario(get_all=get_all) + path = controller_requests.factory_requests_path("running") + return await APIUtils.get(f"{DATABASE_API_URL}{path}", params={"get_all": str(get_all)}) except Exception as e: logging.exception(f"Error obtaining running scenario: {e}") raise HTTPException(status_code=500, detail="Internal server error") -@app.get("/scenarios/check/{role}/{scenario_name}") +@app.get(controller_requests.Routes.CHECK_SCENARIO) async def check_scenario( + user: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid username")], role: Annotated[str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid role")], - scenario_name: Annotated[ - str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name") + federation_id: Annotated[ + str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=64, description="Valid federation identifier") ], ): """ @@ -541,73 +568,91 @@ async def check_scenario( Returns: dict: Whether the scenario is allowed for the role. """ - from nebula.controller.database import check_scenario_with_role - try: - allowed = check_scenario_with_role(role, scenario_name) - return {"allowed": allowed} + path = controller_requests.factory_requests_path( + "check_scenario", + user=user, + role=role, + federation_id=federation_id, + ) + return await APIUtils.get(f"{DATABASE_API_URL}{path}") except Exception as e: logging.exception(f"Error checking scenario with role: {e}") raise HTTPException(status_code=500, detail="Internal server error") -@app.get("/scenarios/{scenario_name}") -async def get_scenario_by_name( - scenario_name: Annotated[ - str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name") +@app.get(controller_requests.Routes.GET_SCENARIO_BY_FEDERATION_ID) +async def get_scenario_by_federation_id( + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), ], ): """ - Fetches a scenario by its name. + Fetches a scenario by its federation identifier. Args: - scenario_name (str): The name of the scenario. + federation_id (str): The identifier of the scenario. Returns: - dict: The scenario data. + dict: The scenario data returned by the database API. """ - from nebula.controller.database import get_scenario_by_name - try: - scenario = get_scenario_by_name(scenario_name) + path = controller_requests.factory_requests_path( + "get_scenarios_by_scenario_name", federation_id=federation_id + ) + return await APIUtils.get(f"{DATABASE_API_URL}{path}") except Exception as e: - logging.exception(f"Error obtaining scenario {scenario_name}: {e}") + logging.exception(f"Error obtaining scenario {federation_id}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return scenario - -@app.get("/nodes/{scenario_name}") -async def list_nodes_by_scenario_name( - scenario_name: Annotated[ - str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name") +@app.get(controller_requests.Routes.NODES_BY_FEDERATION_ID) +async def list_nodes_by_federation_id_endpoint( + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), ], ): """ - Lists all nodes associated with a specific scenario. 
+ Lists all nodes associated with a specific federation identifier. Args: - scenario_name (str): Name of the scenario. + federation_id (str): Identifier of the scenario whose nodes should be listed. Returns: list: List of nodes. """ - from nebula.controller.database import list_nodes_by_scenario_name - try: - nodes = list_nodes_by_scenario_name(scenario_name) + path = controller_requests.factory_requests_path( + "get_nodes_by_scenario_name", federation_id=federation_id + ) + return await APIUtils.get(f"{DATABASE_API_URL}{path}") except Exception as e: - logging.exception(f"Error obtaining nodes: {e}") + logging.exception(f"Error obtaining nodes for {federation_id}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return nodes - -@app.post("/nodes/{scenario_name}/update") +@app.post(controller_requests.Routes.NODES_UPDATE_BY_FEDERATION) async def update_nodes( - scenario_name: Annotated[ + federation_id: Annotated[ str, - Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name"), + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), ], request: Request, ): @@ -615,55 +660,45 @@ async def update_nodes( Updates the configuration of a node in the database and notifies the frontend. Args: - scenario_name (str): The scenario to which the node belongs. + federation_id (str): Identifier of the scenario to update. request (Request): The HTTP request containing the node data. Returns: dict: Confirmation or response from the frontend. """ - from nebula.controller.database import update_node_record - try: - config = await request.json() - timestamp = datetime.datetime.now() - # Update the node in database - await update_node_record( - str(config["device_args"]["uid"]), - str(config["device_args"]["idx"]), - str(config["network_args"]["ip"]), - str(config["network_args"]["port"]), - str(config["device_args"]["role"]), - str(config["network_args"]["neighbors"]), - str(config["mobility_args"]["latitude"]), - str(config["mobility_args"]["longitude"]), - str(timestamp), - str(config["scenario_args"]["federation"]), - str(config["federation_args"]["round"]), - str(config["scenario_args"]["name"]), - str(config["tracking_args"]["run_hash"]), - str(config["device_args"]["malicious"]), - ) + config:dict = await request.json() + config["timestamp"] = str(datetime.now()) + + mobility_args = config.get("mobility_args", None) + if not mobility_args: + # default Murcia coordinates if none provided + config["mobility_args"] = {"latitude": "38.0235", "longitude": "-1.1744"} + # Validate and normalize payload + validated = controller_requests.NodesUpdateRequest(**config) + + # Build payload and include extras with mobility data + payload = validated.model_dump() + payload["extras"] = payload.get("mobility_args", {}) + payload.setdefault("scenario_args", {}) + payload["scenario_args"]["federation"] = federation_id + + # Update the node in database with validated data and extras + path = controller_requests.factory_requests_path("update_nodes") + await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) except Exception as e: logging.exception(f"Error updating nodes: {e}") raise HTTPException(status_code=500, detail="Internal server error") + scenario_name = validated.scenario_args.name url = ( - f"http://{os.environ['NEBULA_CONTROLLER_NAME']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" + 
f"http://{os.environ['NEBULA_ENV_TAG']}_{os.environ['NEBULA_PREFIX_TAG']}_{os.environ['NEBULA_USER_TAG']}_nebula-frontend/platform/dashboard/{scenario_name}/node/update" ) - config["timestamp"] = str(timestamp) - - async with aiohttp.ClientSession() as session: - async with session.post(url, json=config) as response: - if response.status == 200: - return await response.json() - else: - raise HTTPException(status_code=response.status, detail="Error posting data") + return await APIUtils.post(url, data=config) - return {"message": "Nodes updated successfully in the database"} - -@app.post("/nodes/{scenario_name}/done") +@app.post(controller_requests.Routes.NODES_DONE_BY_SCENARIO) async def node_done( scenario_name: Annotated[ str, @@ -683,110 +718,129 @@ async def node_done( Returns the response from the frontend or raises an HTTPException if it fails. """ - url = f"http://{os.environ['NEBULA_CONTROLLER_NAME']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" + url = f"http://{os.environ['NEBULA_ENV_TAG']}_{os.environ['NEBULA_PREFIX_TAG']}_{os.environ['NEBULA_USER_TAG']}_nebula-frontend/platform/dashboard/{scenario_name}/node/done" data = await request.json() - async with aiohttp.ClientSession() as session: - async with session.post(url, json=data) as response: - if response.status == 200: - return await response.json() - else: - raise HTTPException(status_code=response.status, detail="Error posting data") - - return {"message": "Nodes done"} + return await APIUtils.post(url, data=data) -@app.post("/nodes/remove") -async def remove_nodes_by_scenario_name(scenario_name: str = Body(..., embed=True)): +@app.post(controller_requests.Routes.NODES_REMOVE) +async def remove_nodes_by_federation_id_endpoint( + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), + ] +): """ - Endpoint to remove all nodes associated with a scenario. - - Body Parameters: - - scenario_name: Name of the scenario whose nodes should be removed. + Endpoint to remove all nodes associated with a scenario identified by federation ID. Returns a success message or an error if something goes wrong. """ - from nebula.controller.database import remove_nodes_by_scenario_name - try: - remove_nodes_by_scenario_name(scenario_name) + path = controller_requests.factory_requests_path( + "remove_nodes", federation_id=federation_id + ) + await APIUtils.post(f"{DATABASE_API_URL}{path}") except Exception as e: logging.exception(f"Error removing nodes: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return {"message": f"Nodes for scenario {scenario_name} removed successfully"} + return {"message": f"Nodes for federation {federation_id} removed successfully"} -@app.get("/notes/{scenario_name}") -async def get_notes_by_scenario_name( - scenario_name: Annotated[ - str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name") +@app.get(controller_requests.Routes.NOTES_BY_FEDERATION_ID) +async def get_notes_by_federation_id( + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), ], ): """ Endpoint to retrieve notes associated with a scenario. - - Path Parameters: - - scenario_name: Name of the scenario. - - Returns the notes or raises an HTTPException on error. 
""" - from nebula.controller.database import get_notes - try: - notes = get_notes(scenario_name) + path = controller_requests.factory_requests_path( + "get_notes_by_scenario_name", federation_id=federation_id + ) + return await APIUtils.get(f"{DATABASE_API_URL}{path}") except Exception as e: - logging.exception(f"Error obtaining notes {notes}: {e}") + logging.exception(f"Error obtaining notes for federation {federation_id}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return notes - -@app.post("/notes/update") -async def update_notes_by_scenario_name(scenario_name: str = Body(..., embed=True), notes: str = Body(..., embed=True)): +@app.post(controller_requests.Routes.NOTES_UPDATE) +async def update_notes_by_federation_id( + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), + ], + notes: str = Body(..., embed=True), +): """ Endpoint to update notes for a given scenario. Body Parameters: - - scenario_name: Name of the scenario. - notes: Text content to store as notes. Returns a success message or an error if something goes wrong. """ - from nebula.controller.database import save_notes - try: - save_notes(scenario_name, notes) + payload = controller_requests.NotesUpdateRequest(notes=notes).model_dump() + path = controller_requests.factory_requests_path( + "update_notes", federation_id=federation_id + ) + return await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) except Exception as e: - logging.exception(f"Error updating notes: {e}") + logging.exception(f"Error updating notes for federation {federation_id}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return {"message": f"Notes for scenario {scenario_name} updated successfully"} - -@app.post("/notes/remove") -async def remove_notes_by_scenario_name(scenario_name: str = Body(..., embed=True)): +@app.post(controller_requests.Routes.NOTES_REMOVE) +async def remove_notes_by_federation_id_endpoint( + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), + ] +): """ - Endpoint to remove notes associated with a scenario. - - Body Parameters: - - scenario_name: Name of the scenario. + Endpoint to remove notes associated with a scenario identified by federation ID. Returns a success message or an error if something goes wrong. """ - from nebula.controller.database import remove_note - try: - remove_note(scenario_name) + path = controller_requests.factory_requests_path( + "remove_notes", federation_id=federation_id + ) + await APIUtils.post(f"{DATABASE_API_URL}{path}") except Exception as e: - logging.exception(f"Error removing notes: {e}") + logging.exception(f"Error removing notes for federation {federation_id}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return {"message": f"Notes for scenario {scenario_name} removed successfully"} + return {"message": f"Notes for federation {federation_id} removed successfully"} -@app.get("/user/list") +@app.get(controller_requests.Routes.USER_LIST) async def list_users_controller(all_info: bool = False): """ Endpoint to list all users in the database. @@ -796,44 +850,45 @@ async def list_users_controller(all_info: bool = False): Returns a list of users or raises an HTTPException on error. 
""" - from nebula.controller.database import list_users - try: - user_list = list_users(all_info) - if all_info: - # Convert each sqlite3.Row to a dictionary so that it is JSON serializable. - user_list = [dict(user) for user in user_list] - return {"users": user_list} + path = controller_requests.factory_requests_path("list_users") + return await APIUtils.get(f"{DATABASE_API_URL}{path}", params={"all_info": str(all_info)}) except Exception as e: + logging.exception(f"Error retrieving users: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error retrieving users: {e}") -@app.get("/user/{scenario_name}") -async def get_user_by_scenario_name( - scenario_name: Annotated[ - str, Path(regex="^[a-zA-Z0-9_-]+$", min_length=1, max_length=50, description="Valid scenario name") +@app.get(controller_requests.Routes.USER_BY_FEDERATION_ID) +async def get_user_by_federation_id_endpoint( + federation_id: Annotated[ + str, + Path( + regex="^[a-zA-Z0-9_-]+$", + min_length=1, + max_length=64, + description="Valid federation identifier", + ), ], ): """ - Endpoint to retrieve the user assigned to a scenario. + Endpoint to retrieve the user assigned to a scenario identified by federation ID. Path Parameters: - - scenario_name: Name of the scenario. + - federation_id: Identifier of the scenario. Returns user info or raises an HTTPException on error. """ - from nebula.controller.database import get_user_by_scenario_name - try: - user = get_user_by_scenario_name(scenario_name) + path = controller_requests.factory_requests_path( + "get_user_by_scenario_name", federation_id=federation_id + ) + return await APIUtils.get(f"{DATABASE_API_URL}{path}") except Exception as e: - logging.exception(f"Error obtaining user {user}: {e}") + logging.exception(f"Error obtaining user for federation {federation_id}: {e}") raise HTTPException(status_code=500, detail="Internal server error") - return user - -@app.get("/discover-vpn") +@app.get(controller_requests.Routes.DISCOVER_VPN) async def discover_vpn(): """ Calls the Tailscale CLI to fetch the current status in JSON format, @@ -847,45 +902,45 @@ async def discover_vpn(): stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) - + # 2) Wait for it to finish and capture stdout/stderr out, err = await proc.communicate() if proc.returncode != 0: # If the CLI returned an error, raise to be caught below raise RuntimeError(err.decode()) - + # 3) Parse the JSON output data = json.loads(out.decode()) - + # 4) Collect only the IPv4 addresses from each peer ips = [] for peer in data.get("Peer", {}).values(): for ip in peer.get("TailscaleIPs", []): - if ":" not in ip: + if ":" not in ip: # Skip IPv6 entries (they contain colons) ips.append(ip) - + # 5) Return the list of IPv4s return {"ips": ips} - + except Exception as e: # 6) Log any failure and respond with HTTP 500 logging.error(f"Error discovering VPN devices: {e}") raise HTTPException(status_code=500, detail="No devices discovered") -@app.get("/physical/run/{ip}", tags=["physical"]) +@app.get(controller_requests.Routes.PHYSICAL_RUN, tags=["physical"]) async def physical_run(ip: str): status, data = await remote_get(ip, "/run/") - + if status == 200: return data if status is None: raise HTTPException(status_code=502, detail=f"Node unreachable: {data}") raise HTTPException(status_code=status, detail=data) - - -@app.get("/physical/stop/{ip}", tags=["physical"]) + + +@app.get(controller_requests.Routes.PHYSICAL_STOP, tags=["physical"]) async def physical_stop(ip: str): status, data = await 
remote_get(ip, "/stop/") if status == 200: @@ -893,9 +948,9 @@ async def physical_stop(ip: str): if status is None: raise HTTPException(status_code=502, detail=f"Node unreachable: {data}") raise HTTPException(status_code=status, detail=data) - - -@app.put("/physical/setup/{ip}", tags=["physical"], + + +@app.put(controller_requests.Routes.PHYSICAL_SETUP, tags=["physical"], status_code=status.HTTP_201_CREATED) async def physical_setup( ip: str, @@ -903,7 +958,7 @@ async def physical_setup( global_test: UploadFile = File(..., description="Global Dataset*.h5*"), train_set: UploadFile = File(..., description="Training dataset*.h5*"), ): - + form = aiohttp.FormData() await config.seek(0) form.add_field("config", config.file, @@ -914,40 +969,40 @@ async def physical_setup( await train_set.seek(0) form.add_field("train_set", train_set.file, filename=train_set.filename, content_type="application/octet-stream") - + status_code, data = await remote_post_form( ip, "/setup/", form, method="PUT" ) - + if status_code == 201: return data if status_code is None: raise HTTPException(status_code=502, detail=f"Node unreachable: {data}") raise HTTPException(status_code=status_code, detail=data) - + # ────────────────────────────────────────────────────────────── # Physical · single-node state # ────────────────────────────────────────────────────────────── -@app.get("/physical/state/{ip}", tags=["physical"]) +@app.get(controller_requests.Routes.PHYSICAL_STATE, tags=["physical"]) async def get_physical_node_state(ip: str): """ Query a single Raspberry Pi (or other node) for its training state. - + Parameters ---------- ip : str IP address or hostname of the node. - + Returns ------- dict - • running (bool) – True if a training process is active. + • running (bool) – True if a training process is active. • error (str) – Optional error message when the node is unreachable or returns a non-200 HTTP status. """ # Short global timeout so a dead node doesn't block the whole request timeout = aiohttp.ClientTimeout(total=3) # seconds - + try: async with aiohttp.ClientSession(timeout=timeout) as session: async with session.get(f"http://{ip}/state/") as resp: @@ -960,21 +1015,21 @@ async def get_physical_node_state(ip: str): except Exception as exc: # Network errors, timeouts, DNS failures, … return {"running": False, "error": str(exc)} - - + + # ────────────────────────────────────────────────────────────── # Physical · aggregate state for an entire scenario # ────────────────────────────────────────────────────────────── -@app.get("/physical/scenario-state/{scenario_name}", tags=["physical"]) -async def get_physical_scenario_state(scenario_name: str): +@app.get(controller_requests.Routes.PHYSICAL_SCENARIO_STATE, tags=["physical"]) +async def get_physical_scenario_state(federation_id: str): """ Check the training state of *every* physical node assigned to a scenario. - + Parameters ---------- - scenario_name : str + federation_id : str Scenario identifier. 
- + Returns ------- dict @@ -986,19 +1041,19 @@ async def get_physical_scenario_state(scenario_name: str): } """ # 1) Retrieve scenario metadata and node list from the DB - scenario = await get_scenario_by_name(scenario_name) + scenario = await get_scenario_by_federation_id(federation_id) if not scenario: raise HTTPException(status_code=404, detail="Scenario not found") - - nodes = await list_nodes_by_scenario_name(scenario_name) + + nodes = await list_nodes_by_federation_id_endpoint(federation_id) if not nodes: raise HTTPException(status_code=404, detail="No nodes found for scenario") - + # 2) Probe all nodes concurrently ips = [n["ip"] for n in nodes] tasks = [get_physical_node_state(ip) for ip in ips] states = await asyncio.gather(*tasks) # parallel HTTP calls - + # 3) Aggregate results nodes_state = dict(zip(ips, states)) any_running = any(s.get("running") for s in states) @@ -1007,7 +1062,7 @@ async def get_physical_scenario_state(scenario_name: str): all_available = all( (not s.get("running")) and (not s.get("error")) for s in states ) - + return { "running": any_running, "nodes_state": nodes_state, @@ -1015,7 +1070,7 @@ async def get_physical_scenario_state(scenario_name: str): } -@app.post("/user/add") +@app.post(controller_requests.Routes.USER_ADD) async def add_user_controller(user: str = Body(...), password: str = Body(...), role: str = Body(...)): """ Endpoint to add a new user to the database. @@ -1027,17 +1082,16 @@ async def add_user_controller(user: str = Body(...), password: str = Body(...), Returns a success message or an error if the user could not be added. """ - from nebula.controller.database import add_user - try: - add_user(user, password, role) - return {"detail": "User added successfully"} + payload = controller_requests.UserAddRequest(user=user, password=password, role=role).model_dump() + path = controller_requests.factory_requests_path("add_user") + return await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) except Exception as e: logging.exception(f"Error adding user: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error adding user: {e}") -@app.post("/user/delete") +@app.post(controller_requests.Routes.USER_DELETE) async def remove_user_controller(user: str = Body(..., embed=True)): """ Controller endpoint that inserts a new user into the database. @@ -1047,17 +1101,16 @@ async def remove_user_controller(user: str = Body(..., embed=True)): Returns a success message if the user is deleted, or an HTTP error if an exception occurs. """ - from nebula.controller.database import delete_user_from_db - try: - delete_user_from_db(user) - return {"detail": "User deleted successfully"} + path = controller_requests.factory_requests_path("delete_user") + payload = controller_requests.UserDeleteRequest(user=user).model_dump() + return await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) except Exception as e: logging.exception(f"Error deleting user: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error deleting user: {e}") -@app.post("/user/update") +@app.post(controller_requests.Routes.USER_UPDATE) async def update_user_controller(user: str = Body(...), password: str = Body(...), role: str = Body(...)): """ Controller endpoint that modifies a user of the database. @@ -1069,17 +1122,16 @@ async def update_user_controller(user: str = Body(...), password: str = Body(... Returns a success message if the user is updated, or an HTTP error if an exception occurs. 
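# Illustrative sketch (not from the patch): the scenario-state endpoint above probes
# every node concurrently and then reduces the per-node results. The reduction is easy
# to check in isolation with made-up node states:
example_states = [
    {"running": True},                       # node still training
    {"running": False},                      # node idle and reachable
    {"running": False, "error": "timeout"},  # node unreachable
]

any_running = any(s.get("running") for s in example_states)
all_available = all((not s.get("running")) and (not s.get("error")) for s in example_states)

print(any_running)    # True  -> the scenario is reported as running
print(all_available)  # False -> at least one node is busy or unreachable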
""" - from nebula.controller.database import update_user - try: - update_user(user, password, role) - return {"detail": "User updated successfully"} + payload = controller_requests.UserUpdateRequest(user=user, password=password, role=role).model_dump() + path = controller_requests.factory_requests_path("update_user") + return await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) except Exception as e: logging.exception(f"Error updating user: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error updating user: {e}") -@app.post("/user/verify") +@app.post(controller_requests.Routes.USER_VERIFY) async def verify_user_controller(user: str = Body(...), password: str = Body(...)): """ Endpoint to verify user credentials. @@ -1090,16 +1142,13 @@ async def verify_user_controller(user: str = Body(...), password: str = Body(... Returns the user role on success or raises an error on failure. """ - from nebula.controller.database import get_user_info, list_users, verify - try: - user_submitted = user.upper() - if (user_submitted in list_users()) and verify(user_submitted, password): - user_info = get_user_info(user_submitted) - return {"user": user_submitted, "role": user_info[2]} - else: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - except Exception as e: + payload = controller_requests.UserVerifyRequest(user=user, password=password).model_dump() + path = controller_requests.factory_requests_path("verify_user") + return await APIUtils.post(f"{DATABASE_API_URL}{path}", data=payload) + except HTTPException as e: + if e.status_code == 401: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from e logging.exception(f"Error verifying user: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error verifying user: {e}") diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py index acf1e98e5..64c5d84a2 100644 --- a/nebula/controller/scenarios.py +++ b/nebula/controller/scenarios.py @@ -199,30 +199,30 @@ def __init__( self.mobile_participants_percent = mobile_participants_percent self.additional_participants = additional_participants self.with_trustworthiness = with_trustworthiness - self.robustness_pillar = robustness_pillar, - self.resilience_to_attacks = resilience_to_attacks, - self.algorithm_robustness = algorithm_robustness, - self.client_reliability = client_reliability, - self.privacy_pillar = privacy_pillar, - self.technique = technique, - self.uncertainty = uncertainty, - self.indistinguishability = indistinguishability, - self.fairness_pillar = fairness_pillar, - self.selection_fairness = selection_fairness, - self.performance_fairness = performance_fairness, - self.class_distribution = class_distribution, - self.explainability_pillar = explainability_pillar, - self.interpretability = interpretability, - self.post_hoc_methods = post_hoc_methods, - self.accountability_pillar = accountability_pillar, - self.factsheet_completeness = factsheet_completeness, - self.architectural_soundness_pillar = architectural_soundness_pillar, - self.client_management = client_management, - self.optimization = optimization, - self.sustainability_pillar = sustainability_pillar, - self.energy_source = energy_source, - self.hardware_efficiency = hardware_efficiency, - self.federation_complexity = federation_complexity, + self.robustness_pillar = (robustness_pillar,) + self.resilience_to_attacks = (resilience_to_attacks,) + self.algorithm_robustness = (algorithm_robustness,) + 
self.client_reliability = (client_reliability,) + self.privacy_pillar = (privacy_pillar,) + self.technique = (technique,) + self.uncertainty = (uncertainty,) + self.indistinguishability = (indistinguishability,) + self.fairness_pillar = (fairness_pillar,) + self.selection_fairness = (selection_fairness,) + self.performance_fairness = (performance_fairness,) + self.class_distribution = (class_distribution,) + self.explainability_pillar = (explainability_pillar,) + self.interpretability = (interpretability,) + self.post_hoc_methods = (post_hoc_methods,) + self.accountability_pillar = (accountability_pillar,) + self.factsheet_completeness = (factsheet_completeness,) + self.architectural_soundness_pillar = (architectural_soundness_pillar,) + self.client_management = (client_management,) + self.optimization = (optimization,) + self.sustainability_pillar = (sustainability_pillar,) + self.energy_source = (energy_source,) + self.hardware_efficiency = (hardware_efficiency,) + self.federation_complexity = (federation_complexity,) self.schema_additional_participants = schema_additional_participants self.random_topology_probability = random_topology_probability self.with_sa = with_sa @@ -549,6 +549,26 @@ def from_dict(cls, data): return scenario + @staticmethod + def to_json(scenario_obj): + """ + Converts a Scenario object to a JSON string. + + Args: + scenario_obj (Scenario): An instance of the Scenario class. + + Returns: + str: A JSON string representation of the Scenario object. + """ + if not isinstance(scenario_obj, Scenario): + raise TypeError("Input must be an instance of the Scenario class.") + + # Get all attributes of the Scenario object + scenario_dict = scenario_obj.__dict__ + + # Convert the dictionary to a JSON string + return json.dumps(scenario_dict, indent=2) # Using indent for pretty-printing + # Class to manage the current scenario class ScenarioManagement: @@ -579,12 +619,16 @@ def __init__(self, scenario, user=None): self.scenario_name = f"nebula_{self.scenario.federation}_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}" self.root_path = os.environ.get("NEBULA_ROOT_HOST") self.host_platform = os.environ.get("NEBULA_HOST_PLATFORM") - self.config_dir = os.path.join(os.environ.get("NEBULA_CONFIG_DIR"), self.scenario_name) - self.log_dir = os.environ.get("NEBULA_LOGS_DIR") - self.cert_dir = os.environ.get("NEBULA_CERTS_DIR") - self.advanced_analytics = os.environ.get("NEBULA_ADVANCED_ANALYTICS", "False") == "True" + self.config_dir = os.path.join(os.environ.get("NEBULA_CONFIG_DIR", "config"), self.scenario_name) + self.log_dir = os.environ.get("NEBULA_LOGS_DIR", "logs") + self.cert_dir = os.environ.get("NEBULA_CERTS_DIR", "certs") self.config = Config(entity="scenarioManagement") + # Tag-based naming for scenario resources + self.env_tag = os.environ.get("NEBULA_ENV_TAG", "dev") + self.prefix_tag = os.environ.get("NEBULA_PREFIX_TAG", "dev") + self.user_tag = os.environ.get("NEBULA_USER_TAG", os.environ.get("USER", "unknown")) + # If physical set the neighbours correctly if self.scenario.deployment == "physical" and self.scenario.physical_ips: for idx, ip in enumerate(self.scenario.physical_ips): @@ -761,6 +805,26 @@ def __init__(self, scenario, user=None): with open(participant_file, "w") as f: json.dump(participant_config, f, sort_keys=False, indent=2) + def get_network_name(self, suffix: str) -> str: + """ + Generate a standardized network name using tags. + Args: + suffix (str): Suffix for the network (e.g. 'net-base' or '<scenario_name>-net-scenario'). + Returns: + str: The composed network name. + """ + return f"{self.env_tag}_{self.prefix_tag}_{self.user_tag}_{suffix}" + + def get_participant_container_name(self, idx: int) -> str: + """ + Generate a standardized container name for a participant using tags. + Args: + idx (int): The participant index. + Returns: + str: The composed container name.
+ """ + return f"{self.env_tag}_{self.prefix_tag}_{self.user_tag}_{self.scenario_name}_participant{idx}" + @staticmethod def stop_participants(scenario_name=None): """ @@ -870,7 +934,6 @@ async def load_configurations_and_start_nodes( # Update participants configuration is_start_node = False config_participants = [] - # ap = len(additional_participants) if additional_participants else 0 additional_nodes = len(additional_participants) if additional_participants else 0 logging.info(f"######## nodes: {self.n_nodes} + additionals: {additional_nodes} ######") @@ -902,7 +965,7 @@ async def load_configurations_and_start_nodes( participant_config["mobility_args"]["latitude"] = self.scenario.latitude participant_config["mobility_args"]["longitude"] = self.scenario.longitude # If not, use the given coordinates in the frontend - participant_config["tracking_args"]["local_tracking"] = "advanced" if self.advanced_analytics else "basic" + participant_config["tracking_args"]["local_tracking"] = "default" participant_config["tracking_args"]["log_dir"] = self.log_dir participant_config["tracking_args"]["config_dir"] = self.config_dir @@ -950,10 +1013,11 @@ async def load_configurations_and_start_nodes( with open(additional_participant_file) as f: participant_config = json.load(f) + + logging.info(f"Configuration | additional nodes | participant: {self.n_nodes + i + 1}") last_ip = participant_config["network_args"]["ip"] - logging.info(f"Valores de la ultima ip: ({last_ip})") participant_config["scenario_args"]["n_nodes"] = self.n_nodes + additional_nodes # self.n_nodes + i + 1 participant_config["device_args"]["idx"] = last_participant_index + i participant_config["network_args"]["neighbors"] = "" @@ -1178,7 +1242,8 @@ def start_nodes_docker(self): logging.info("Starting nodes using Docker Compose...") logging.info(f"env path: {self.env_path}") - network_name = f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_{str(self.user).lower()}-nebula-net-scenario" + network_name = self.get_network_name(f"{self.scenario_name}-net-scenario") + base_network_name = self.get_network_name("net-base") # Create the Docker network base = DockerUtils.create_docker_network(network_name) @@ -1188,9 +1253,10 @@ def start_nodes_docker(self): self.config.participants.sort(key=lambda x: x["device_args"]["idx"]) i = 2 container_ids = [] + container_names = [] # Track names for metadata for idx, node in enumerate(self.config.participants): image = "nebula-core" - name = f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_{self.user}-participant{node['device_args']['idx']}" + name = self.get_participant_container_name(node["device_args"]["idx"]) if node["device_args"]["accelerator"] == "gpu": environment = { @@ -1223,10 +1289,10 @@ def start_nodes_docker(self): ] networking_config = client.api.create_networking_config({ - f"{network_name}": client.api.create_endpoint_config( + network_name: client.api.create_endpoint_config( ipv4_address=f"{base}.{i}", ), - f"{os.environ.get('NEBULA_CONTROLLER_NAME')}_nebula-net-base": client.api.create_endpoint_config(), + base_network_name: client.api.create_endpoint_config(), }) node["tracking_args"]["log_dir"] = "/nebula/app/logs" @@ -1238,6 +1304,12 @@ def start_nodes_docker(self): node["security_args"]["cafile"] = "/nebula/app/certs/ca_cert.pem" node = json.loads(json.dumps(node).replace("192.168.50.", f"{base}.")) # TODO change this + try: + existing = client.containers.get(name) + logging.warning(f"Container {name} already exists. 
Deployment may fail or cause conflicts.") + except docker.errors.NotFound: + pass # No conflict, safe to proceed + # Write the config file in config directory with open(f"{self.config_dir}/participant_{node['device_args']['idx']}.json", "w") as f: json.dump(node, f, indent=4) @@ -1259,10 +1331,16 @@ def start_nodes_docker(self): try: client.api.start(container_id) container_ids.append(container_id) + container_names.append(name) except Exception as e: logging.exception(f"Starting participant {name} error: {e}") i += 1 + # Write scenario-level metadata for cleanup + scenario_metadata = {"containers": container_names, "network": network_name} + with open(os.path.join(self.config_dir, "scenario.metadata"), "w") as f: + json.dump(scenario_metadata, f, indent=2) + def start_nodes_process(self): """ Starts participant nodes as independent background processes on the host machine. @@ -1529,3 +1607,94 @@ def scenario_finished(self, timeout_seconds): return False time.sleep(5) + + @staticmethod + def cleanup_scenario_containers(): + """ + Remove all participant containers and the scenario network. + Reads ALL scenario.metadata and removes all listed containers and the network, then deletes the metadata file. + Also forcibly stops and removes any containers still attached to the network before removing it. + """ + import json + import logging + import os + + import docker + + # Try multiple possible config directory locations. This depends on where the user called the function from. + possible_config_dirs = [ + os.environ.get("NEBULA_CONFIG_DIR"), + "/nebula/app/config", + "./app/config", + os.path.join(os.getcwd(), "app", "config"), + os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "app", "config"), + ] + + config_dir = None + for dir_path in possible_config_dirs: + if dir_path and os.path.exists(dir_path): + config_dir = dir_path + break + + if not config_dir: + logging.warning("No valid config directory found, skipping cleanup") + return + + scenario_dirs = [] + logging.info(f"Config directory: {config_dir}") + if os.path.exists(config_dir): + for item in os.listdir(config_dir): + scenario_path = os.path.join(config_dir, item) + if os.path.isdir(scenario_path): + metadata_file = os.path.join(scenario_path, "scenario.metadata") + if os.path.exists(metadata_file): + scenario_dirs.append(scenario_path) + + logging.info(f"Removing scenario containers for {scenario_dirs}") + if not scenario_dirs: + logging.info("No active scenarios found to clean up") + return + + client = docker.from_env() + + for scenario_dir in scenario_dirs: + metadata_path = os.path.join(scenario_dir, "scenario.metadata") + if not os.path.exists(metadata_path): + logging.info(f"Skipping {scenario_dir} - no scenario.metadata found") + continue + + with open(metadata_path) as f: + meta = json.load(f) + + # Remove containers listed in metadata + for name in meta.get("containers", []): + try: + container = client.containers.get(name) + container.remove(force=True) + logging.info(f"Removed scenario container {name}") + except Exception as e: + logging.warning(f"Could not remove scenario container {name}: {e}") + + # Remove network, but first forcibly remove any containers still attached + network_name = meta.get("network") + if network_name: + try: + network = client.networks.get(network_name) + attached_containers = network.attrs.get("Containers") or {} + for container_id in attached_containers: + try: + c = client.containers.get(container_id) + c.remove(force=True) + logging.info(f"Force-removed container 
{c.name} attached to {network_name}") + except Exception as e: + logging.warning(f"Could not force-remove container {container_id}: {e}") + network.remove() + logging.info(f"Removed scenario network {network_name}") + except Exception as e: + logging.warning(f"Could not remove scenario network {network_name}: {e}") + + # Remove metadata file + try: + os.remove(metadata_path) + except Exception as e: + logging.warning(f"Could not remove scenario.metadata: {e}") diff --git a/nebula/controller/start_services.sh b/nebula/controller/start_services.sh index 98e53411c..d506eab54 100644 --- a/nebula/controller/start_services.sh +++ b/nebula/controller/start_services.sh @@ -14,10 +14,12 @@ NEBULA_SOCK=nebula.sock echo "NEBULA_PRODUCTION: $NEBULA_PRODUCTION" if [ "$NEBULA_PRODUCTION" = "False" ]; then echo "Starting Gunicorn in dev mode..." - uvicorn nebula.controller.controller:app --host 0.0.0.0 --port $NEBULA_CONTROLLER_PORT --log-level debug --proxy-headers --forwarded-allow-ips "*" & + uvicorn nebula.controller.hub:app --host 0.0.0.0 --port $NEBULA_CONTROLLER_PORT --log-level debug --proxy-headers --forwarded-allow-ips "*" & + uvicorn nebula.controller.federation.federation_api:app --host 0.0.0.0 --port $NEBULA_FEDERATION_CONTROLLER_PORT --log-level debug --proxy-headers --forwarded-allow-ips "*" & else echo "Starting Gunicorn in production mode..." - uvicorn nebula.controller.controller:app --host 0.0.0.0 --port $NEBULA_CONTROLLER_PORT --log-level info --proxy-headers --forwarded-allow-ips "*" & + uvicorn nebula.controller.hub:app --host 0.0.0.0 --port $NEBULA_CONTROLLER_PORT --log-level info --proxy-headers --forwarded-allow-ips "*" & + uvicorn nebula.controller.federation.federation_api:app --host 0.0.0.0 --port $NEBULA_FEDERATION_CONTROLLER_PORT --log-level debug --proxy-headers --forwarded-allow-ips "*" & fi tail -f /dev/null diff --git a/nebula/controller/utils_requests.py b/nebula/controller/utils_requests.py new file mode 100644 index 000000000..240168886 --- /dev/null +++ b/nebula/controller/utils_requests.py @@ -0,0 +1,209 @@ +from typing import Any, Dict, List + +from pydantic import BaseModel, conint, confloat + + +class Routes: + # General + INIT = "/" + STATUS = "/status" + RESOURCES = "/resources" + LEAST_MEMORY_GPU = "/least_memory_gpu" + AVAILABLE_GPUS = "/available_gpus/" + + # Scenarios (Controller + DB API routing) + RUN = "/scenarios/run" + UPDATE = "/scenarios/{federation_id}/update" + STOP = "/scenarios/{federation_id}/stop" + REMOVE = "/scenarios/{federation_id}/remove" + FINISH = "/scenarios/{federation_id}/set_status_to_finished" + RUNNING = "/scenarios/running" + CHECK_SCENARIO = "/scenarios/check/{user}/{role}/{federation_id}" + GET_SCENARIOS_BY_USER = "/scenarios/{user}/{role}" + GET_SCENARIO_BY_FEDERATION_ID = "/scenarios/{federation_id}" + + # Nodes + NODES_BY_FEDERATION_ID = "/nodes/{federation_id}" + NODES_UPDATE = "/nodes/update" + NODES_UPDATE_BY_FEDERATION = "/nodes/{federation_id}/update" + NODES_DONE_BY_SCENARIO = "/nodes/{scenario_name}/done" + NODES_REMOVE = "/nodes/{federation_id}/remove" + + # Notes + NOTES_BY_FEDERATION_ID = "/notes/{federation_id}" + NOTES_UPDATE = "/notes/{federation_id}/update" + NOTES_REMOVE = "/notes/{federation_id}/remove" + + # Users + USER_LIST = "/user/list" + USER_BY_FEDERATION_ID = "/user/{federation_id}" + USER_ADD = "/user/add" + USER_DELETE = "/user/delete" + USER_UPDATE = "/user/update" + USER_VERIFY = "/user/verify" + + # Discovery / Physical management + DISCOVER_VPN = "/discover-vpn" + PHYSICAL_RUN = "/physical/run" + 
PHYSICAL_STOP = "/physical/stop" + PHYSICAL_SETUP = "/physical/setup" + PHYSICAL_STATE = "/physical/state" + PHYSICAL_SCENARIO_STATE = "/physical/{federation_id}/state" + + +class RunScenarioRequest(BaseModel): + """Request model to trigger a scenario run on the controller. + + - Only requires scenario_data and user. + - Extra fields (e.g., role, federation_id) are ignored. + """ + scenario_data: Dict[str, Any] + user: str + + +class ScenarioUpdateRequest(BaseModel): + alias: str + scenario_name: str + start_time: str + end_time: str + scenario: Dict[str, Any] + status: str + username: str + + +class ScenarioStopRequest(BaseModel): + all: bool = False + + +class ScenarioRemoveRequest(BaseModel): + scenario_name: str + + +class ScenarioFinishRequest(BaseModel): + all: bool = False + + +class NotesUpdateRequest(BaseModel): + notes: str + + +class UserAddRequest(BaseModel): + user: str + password: str + role: str + + +class UserDeleteRequest(BaseModel): + user: str + + +class UserUpdateRequest(BaseModel): + user: str + password: str + role: str + + +class UserVerifyRequest(BaseModel): + user: str + password: str + + +# Nodes update payload +class DeviceArgs(BaseModel): + uid: str + idx: int + role: str + malicious: bool + + +class NetworkArgs(BaseModel): + ip: str + port: conint(ge=1, le=65535) # type: ignore[valid-type] + neighbors: List[Any] + + +class MobilityArgs(BaseModel): + latitude: confloat(ge=-90, le=90) # type: ignore[valid-type] + longitude: confloat(ge=-180, le=180) # type: ignore[valid-type] + + +class TrackingArgs(BaseModel): + run_hash: str + + +class FederationArgs(BaseModel): + round: int + + +class ScenarioArgs(BaseModel): + federation: str + name: str + + +class NodesUpdateRequest(BaseModel): + device_args: DeviceArgs + network_args: NetworkArgs + mobility_args: MobilityArgs + tracking_args: TrackingArgs + federation_args: FederationArgs + scenario_args: ScenarioArgs + timestamp: str + + +def factory_requests_path( + resource: str, + user: str = "", + role: str = "", + federation_id: str = "", +) -> str: + """Build paths for requests to the Database API from the Controller. + + This factory only maps DB API resources; controller endpoints do not require mapping here. 
+ """ + if resource == "init": + return Routes.INIT + elif resource == "update": + return Routes.UPDATE.format(federation_id=federation_id) + elif resource == "stop": + return Routes.STOP.format(federation_id=federation_id) + elif resource == "remove": + return Routes.REMOVE.format(federation_id=federation_id) + elif resource == "finish": + return Routes.FINISH.format(federation_id=federation_id) + elif resource == "running": + return Routes.RUNNING + elif resource == "check_scenario": + return Routes.CHECK_SCENARIO.format(user=user, role=role, federation_id=federation_id) + elif resource == "get_scenarios_by_user": + return Routes.GET_SCENARIOS_BY_USER.format(user=user, role=role) + elif resource == "get_scenarios_by_scenario_name": + return Routes.GET_SCENARIO_BY_FEDERATION_ID.format(federation_id=federation_id) + # Nodes + elif resource == "get_nodes_by_scenario_name": + return Routes.NODES_BY_FEDERATION_ID.format(federation_id=federation_id) + elif resource == "update_nodes": + return Routes.NODES_UPDATE + elif resource == "remove_nodes": + return Routes.NODES_REMOVE.format(federation_id=federation_id) + # Notes + elif resource == "get_notes_by_scenario_name": + return Routes.NOTES_BY_FEDERATION_ID.format(federation_id=federation_id) + elif resource == "update_notes": + return Routes.NOTES_UPDATE.format(federation_id=federation_id) + elif resource == "remove_notes": + return Routes.NOTES_REMOVE.format(federation_id=federation_id) + # Users + elif resource == "list_users": + return Routes.USER_LIST + elif resource == "get_user_by_scenario_name": + return Routes.USER_BY_FEDERATION_ID.format(federation_id=federation_id) + elif resource == "add_user": + return Routes.USER_ADD + elif resource == "delete_user": + return Routes.USER_DELETE + elif resource == "update_user": + return Routes.USER_UPDATE + elif resource == "verify_user": + return Routes.USER_VERIFY + else: + raise Exception(f"resource not found: {resource}") diff --git a/nebula/core/addonmanager.py b/nebula/core/addonmanager.py index 46fbfd60a..f3bf3a505 100644 --- a/nebula/core/addonmanager.py +++ b/nebula/core/addonmanager.py @@ -1,5 +1,6 @@ import logging from typing import TYPE_CHECKING +from abc import ABC, abstractmethod from nebula.addons.functions import print_msg_box from nebula.addons.gps.gpsmodule import factory_gpsmodule @@ -10,6 +11,15 @@ if TYPE_CHECKING: from nebula.core.engine import Engine +class NebulaAddon(ABC): + @abstractmethod + async def start(): + raise NotImplementedError + + @abstractmethod + async def stop(): + raise NotImplementedError + class AddondManager: """ @@ -51,24 +61,43 @@ async def deploy_additional_services(self): - Services are only launched if the corresponding configuration flags are set. 
""" print_msg_box(msg="Deploying Additional Services", indent=2, title="Addons Manager") - if self._config.participant["trustworthiness"]: - from nebula.addons.trustworthiness.trustworthiness import Trustworthiness - - trustworthiness = Trustworthiness(self._engine, self._config) - self._addons.append(trustworthiness) - - if self._config.participant["mobility_args"]["mobility"]: - mobility = Mobility(self._config, verbose=False) - self._addons.append(mobility) - - update_interval = 5 - gps = factory_gpsmodule("nebula", self._config, self._engine.addr, update_interval, verbose=False) - self._addons.append(gps) - - if self._config.participant["network_args"]["simulation"]: - refresh_conditions_interval = 5 - network_simulation = factory_network_simulator("nebula", refresh_conditions_interval, "eth0", verbose=False) - self._addons.append(network_simulation) + for addon, addon_config in self._config.participant["addons"].items(): + if addon == "mobility": + mobility = Mobility(self._config, verbose=False) + self._addons.append(mobility) + update_interval = 5 + gps = factory_gpsmodule("nebula", self._config, self._engine.addr, update_interval, verbose=False) + self._addons.append(gps) + elif addon == "trustworthiness": + from nebula.addons.trustworthiness.trustworthiness import Trustworthiness + trustworthiness = Trustworthiness(self._engine, self._config) + self._addons.append(trustworthiness) + elif addon == "network_simulation": + #TODO review parameters for network simulation + type_of_network = self._config.participant["network_args"]["network_simulation"]["type"] + network_config = self._config.participant["network_args"]["network_simulation"]["network_config"] + network_simulation = factory_network_simulator(type_of_network, network_config) + self._addons.append(network_simulation) + #TODO update config access + + # if self._config.participant["trustworthiness"]: + # from nebula.addons.trustworthiness.trustworthiness import Trustworthiness + + # trustworthiness = Trustworthiness(self._engine, self._config) + # self._addons.append(trustworthiness) + + # if self._config.participant["mobility_args"]["mobility"]: + # mobility = Mobility(self._config, verbose=False) + # self._addons.append(mobility) + + # update_interval = 5 + # gps = factory_gpsmodule("nebula", self._config, self._engine.addr, update_interval, verbose=False) + # self._addons.append(gps) + + # if self._config.participant["network_args"]["simulation"]: + # refresh_conditions_interval = 5 + # network_simulation = factory_network_simulator("nebula", refresh_conditions_interval, "eth0", verbose=False) + # self._addons.append(network_simulation) for add in self._addons: await add.start() diff --git a/nebula/core/datasets/nebuladataset.py b/nebula/core/datasets/nebuladataset.py index 0c2e03d8a..a3d1cbe92 100755 --- a/nebula/core/datasets/nebuladataset.py +++ b/nebula/core/datasets/nebuladataset.py @@ -1299,3 +1299,17 @@ def factory_nebuladataset(dataset, **config) -> NebulaDataset: if not cs: raise ValueError(f"Dataset {dataset} not supported") return cs(**config) + +def factory_dataset_setup(dataset) -> dict: + options = { + "MNIST": 10, + "FashionMNIST": 10, + "EMNIST": 47, + "CIFAR10": 10, + "CIFAR100": 100, + } + + num_classes = options.get(dataset, None) + if not num_classes: + raise ValueError(f"Dataset {dataset} not supported") + return num_classes diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 151ae6b22..6195f6926 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -94,7 +94,7 @@ def 
__init__( self.ip = config.participant["network_args"]["ip"] self.port = config.participant["network_args"]["port"] self.addr = config.participant["network_args"]["addr"] - + self.name = config.participant["device_args"]["name"] self.client = docker.from_env() @@ -115,7 +115,6 @@ def __init__( self._aggregator = create_aggregator(config=self.config, engine=self) self._secure_neighbors = [] - self._is_malicious = self.config.participant["adversarial_args"]["attack_params"]["attacks"] != "No Attack" role = config.participant["device_args"]["role"] self._role_behavior: RoleBehavior = factory_role_behavior(role, self, config) @@ -132,7 +131,6 @@ def __init__( msg += f"\nIID: {self.config.participant['data_args']['iid']}" msg += f"\nModel: {model.__class__.__name__}" msg += f"\nAggregation algorithm: {self._aggregator.__class__.__name__}" - msg += f"\nNode behavior: {'malicious' if self._is_malicious else 'benign'}" print_msg_box(msg=msg, indent=2, title="Scenario information") print_msg_box( msg=f"Logging type: {self._trainer.logger.__class__.__name__}", @@ -160,12 +158,12 @@ def __init__( self._addon_manager = AddondManager(self, self.config) # Additional Components - if "situational_awareness" in self.config.participant: + if "situational_awareness" in self.config.participant["addons"]: self._situational_awareness = SituationalAwareness(self.config, self) else: self._situational_awareness = None - if self.config.participant["defense_args"]["reputation"]["enabled"]: + if dict(self.config.participant["addons"]).get("reputation", None): self._reputation = Reputation(engine=self, config=self.config) @property @@ -187,7 +185,7 @@ def aggregator(self): def trainer(self): """Trainer""" return self._trainer - + @property def rb(self): """Role Behavior""" @@ -317,7 +315,7 @@ async def _control_alive_callback(self, source, message): async def _control_leadership_transfer_callback(self, source, message): logging.info(f"🔧 handle_control_message | Trigger | Received leadership transfer message from {source}") - + if await self._round_in_process_lock.locked_async(): logging.info("Learning cycle is executing, role behavior will be modified next round") await self.rb.set_next_role(Role.AGGREGATOR, source_to_notificate=source) @@ -354,7 +352,7 @@ async def _control_leadership_transfer_ack_callback(self, source, message): except TimeoutError: logging.info("Learning cycle is locked, role behavior will be modified next round") await self.rb.set_next_role(Role.TRAINER) - + async def _connection_connect_callback(self, source, message): logging.info(f"🔗 handle_connection_message | Trigger | Received connection message from {source}") @@ -600,7 +598,7 @@ async def start_communications(self): before other services or training processes begin. """ await self.register_events_callbacks() - initial_neighbors = self.config.participant["network_args"]["neighbors"].split() + initial_neighbors = self.config.participant["network_args"]["neighbors"] await self.cm.start_communications(initial_neighbors) await asyncio.sleep(self.config.participant["misc_args"]["grace_time_connection"] // 2) @@ -619,10 +617,10 @@ async def deploy_components(self): the federated learning process starts. 
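The engine now gates optional components on the presence of keys in `config.participant["addons"]` rather than on dedicated flags. The exact schema is not shown in this patch; the sketch below is an assumed shape, with keys taken from those referenced across the diff and placeholder values only:

```python
# Hypothetical "addons" section of a participant config (shape inferred from the
# keys referenced in this diff; all values are placeholders, not real defaults).
participant_addons = {
    "mobility": {"topology_type": "..."},
    "trustworthiness": {},
    "network_simulation": {},            # parameters still under review (see TODO above)
    "reputation": {},
    "situational_awareness": {
        "sa_discovery": {"candidate_selector": "...", "model_handler": "...", "verbose": False},
        "sa_reasoner": {"arbitration_policy": "...", "verbose": False},
    },
    "adversarial_args": {"attacks": "No Attack", "fake_behavior": "..."},
}

# Components are enabled simply by key presence, e.g.:
reputation_enabled = "reputation" in participant_addons
```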
""" await self.aggregator.init() - if "situational_awareness" in self.config.participant: - await self.sa.init() - if self.config.participant["defense_args"]["reputation"]["enabled"]: - await self._reputation.setup() + if "situational_awareness" in self.config.participant["addons"]: + await self.sa.start() + if "reputation" in self.config.participant["addons"]: + await self._reputation.start() await self._reporter.start() await self._addon_manager.deploy_additional_services() @@ -710,10 +708,10 @@ async def _start_learning(self): await self.get_federation_ready_lock().acquire_async() if self.config.participant["device_args"]["start"]: logging.info("Propagate initial model updates.") - + mpe = ModelPropagationEvent(await self.cm.get_addrs_current_connections(only_direct=True, myself=False), "initialization") await EventManager.get_instance().publish_node_event(mpe) - + await self.get_federation_ready_lock().release_async() self.trainer.set_epochs(epochs) @@ -764,7 +762,7 @@ async def learning_cycle_finished(self): return False else: return current_round >= self.total_rounds - + async def resolve_missing_updates(self): """ Delegates the resolution strategy for missing updates to the current role behavior. @@ -778,7 +776,7 @@ async def resolve_missing_updates(self): """ logging.info(f"Using Role behavior: {self.rb.get_role_name()} conflict resolve strategy") return await self.rb.resolve_missing_updates() - + async def update_self_role(self): """ Checks whether a role update is required and performs the transition if necessary. @@ -806,7 +804,7 @@ async def update_self_role(self): logging.info(f"Sending role modification ACK to transferer: {source_to_notificate}") message = self.cm.create_message("control", "leadership_transfer_ack") asyncio.create_task(self.cm.send_message(source_to_notificate, message)) - + async def _learning_cycle(self): """ Main asynchronous loop for executing the Federated Learning process across multiple rounds. 
@@ -837,9 +835,9 @@ async def _learning_cycle(self): indent=2, title="Round information", ) - + await self.update_self_role() - + logging.info(f"Federation nodes: {self.federation_nodes}") await self.update_federation_nodes( await self.cm.get_addrs_current_connections(only_direct=True, myself=True) @@ -851,10 +849,10 @@ async def _learning_cycle(self): logging.info(f"Expected nodes: {expected_nodes}") direct_connections = await self.cm.get_addrs_current_connections(only_direct=True) undirected_connections = await self.cm.get_addrs_current_connections(only_undirected=True) - + logging.info(f"Direct connections: {direct_connections} | Undirected connections: {undirected_connections}") logging.info(f"[Role {self.rb.get_role_name()}] Starting learning cycle...") - + await self.aggregator.update_federation_nodes(expected_nodes) async with self._role_behavior_performance_lock: await self.rb.extended_learning_cycle() @@ -882,13 +880,13 @@ async def _learning_cycle(self): self.trainer.on_learning_cycle_end() await self.trainer.test() - + # Shutdown protocol await self._shutdown_protocol() - + async def _shutdown_protocol(self): logging.info("Starting graceful shutdown process...") - + # 1.- Publish Experiment Finish Event to the last update on modules logging.info("Publishing Experiment Finish Event...") efe = ExperimentFinishEvent() diff --git a/nebula/core/network/communications.py b/nebula/core/network/communications.py index e0b1c17a5..5b022b61e 100755 --- a/nebula/core/network/communications.py +++ b/nebula/core/network/communications.py @@ -88,7 +88,8 @@ def __init__(self, engine: "Engine"): ) self.receive_messages_lock = Locker(name="receive_messages_lock", async_lock=True) - self._discoverer = Discoverer(addr=self.addr, config=self.config) + #self._discoverer = Discoverer(addr=self.addr, config=self.config) + self._discoverer = None # self._health = Health(addr=self.addr, config=self.config) self._health = None self._forwarder = Forwarder(config=self.config) diff --git a/nebula/core/node.py b/nebula/core/node.py index 86a73cc2a..a8df622c2 100755 --- a/nebula/core/node.py +++ b/nebula/core/node.py @@ -42,7 +42,6 @@ from nebula.core.models.mnist.mlp import MNISTModelMLP from nebula.core.engine import Engine from nebula.core.training.lightning import Lightning -from nebula.core.training.siamese import Siamese # os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # os.environ["TORCH_LOGS"] = "+dynamo" @@ -70,7 +69,6 @@ async def main(config: Config): Raises: ValueError: If an unsupported model, dataset, or device role is specified. - NotImplementedError: If an unsupported training strategy (e.g., "scikit") is requested. Returns: Coroutine that initializes and starts the NEBULA node. 
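`main()` in `node.py` selects the dataset and model from the participant config; the new `factory_dataset_setup` helper added in `nebuladataset.py` above resolves a dataset name to its class count and rejects anything outside the supported list:

```python
# Usage sketch for the new factory_dataset_setup helper.
from nebula.core.datasets.nebuladataset import factory_dataset_setup

num_classes = factory_dataset_setup("CIFAR100")   # -> 100

try:
    factory_dataset_setup("TinyImageNet")         # not in the options table
except ValueError as e:
    print(e)                                      # "Dataset TinyImageNet not supported"
```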
@@ -79,7 +77,7 @@ async def main(config: Config): model_name = config.participant["model_args"]["model"] idx = config.participant["device_args"]["idx"] - additional_node_status = config.participant["mobility_args"]["additional_node"]["status"] + additional_node_status = config.participant["deployment_args"]["additional"] # Adjust the total number of nodes and the index of the current node for CFL, as it doesn't require a specific partition for the server (not used for training) if config.participant["scenario_args"]["federation"] == "CFL": @@ -169,10 +167,6 @@ async def main(config: Config): trainer_str = config.participant["training_args"]["trainer"] if trainer_str == "lightning": trainer = Lightning - elif trainer_str == "scikit": - raise NotImplementedError - elif trainer_str == "siamese": - trainer = Siamese else: raise ValueError(f"Trainer {trainer_str} not supported") @@ -185,11 +179,6 @@ def randomize_value(value, variability): config_keys = [ ["reporter_args", "report_frequency"], - ["discoverer_args", "discovery_frequency"], - ["health_args", "health_interval"], - ["health_args", "grace_time_health"], - ["health_args", "check_alive_interval"], - ["health_args", "send_alive_interval"], ["forwarder_args", "forwarder_interval"], ["forwarder_args", "forward_messages_interval"], ] @@ -215,9 +204,10 @@ def randomize_value(value, variability): await node.deploy_federation() if additional_node_status: - time = config.participant["mobility_args"]["additional_node"]["time_start"] - logging.info(f"Waiting time to start finding federation: {time}") - await asyncio.sleep(int(config.participant["mobility_args"]["additional_node"]["time_start"])) + # time = config.participant["addons"]["mobility"]["additional_node"]["time_start"] + # logging.info(f"Waiting time to start finding federation: {time}") + # await asyncio.sleep(int(config.participant["addons"]["mobility"]["additional_node"]["time_start"])) + #await asyncio.sleep(120) #TODO REMOVE await node._aditional_node_start() if node.cm is not None: diff --git a/nebula/core/noderole.py b/nebula/core/noderole.py index 9bd258fef..d1b227e6e 100644 --- a/nebula/core/noderole.py +++ b/nebula/core/noderole.py @@ -190,7 +190,7 @@ def __init__(self, engine: Engine, config: Config): self.attack = create_attack(self._engine) logging.info("Attack behavior created") self.aggregator_bening = self._engine._aggregator - benign_role = self._config.participant["adversarial_args"]["fake_behavior"] + benign_role = self._config.participant["addons"]["adversarial_args"]["fake_behavior"] self._fake_role_behavior = factory_role_behavior(benign_role, self._engine, self._config) self._role = factory_node_role("malicious") @@ -206,7 +206,7 @@ async def extended_learning_cycle(self): try: await self.attack.attack() except Exception: - attack_name = self._config.participant["adversarial_args"]["attacks"] + attack_name = self._config.participant["addons"]["adversarial_args"]["attacks"] logging.exception(f"Attack {attack_name} failed") await self._fake_role_behavior.extended_learning_cycle() diff --git a/nebula/core/situationalawareness/awareness/sareasoner.py b/nebula/core/situationalawareness/awareness/sareasoner.py index 40e6de94d..b7f76e25c 100644 --- a/nebula/core/situationalawareness/awareness/sareasoner.py +++ b/nebula/core/situationalawareness/awareness/sareasoner.py @@ -85,9 +85,10 @@ def __init__( title="SA Reasoner", ) logging.info("🌐 Initializing SAReasoner") - self._config = copy.deepcopy(config.participant) + self._is_additional_node = 
config.participant["deployment_args"]["additional"] + self._config = copy.deepcopy(config.participant["addons"]) self._addr = config.participant["network_args"]["addr"] - self._topology = config.participant["mobility_args"]["topology_type"] + self._topology = config.participant["addons"]["mobility"]["topology_type"] self._situational_awareness_network: SANetwork | None = None self._situational_awareness_training = None self._restructure_process_lock = Locker(name="restructure_process_lock", async_lock=True) @@ -96,11 +97,11 @@ def __init__( self._suggestion_buffer = SuggestionBuffer(self._arbitrator_notification, verbose=True) self._communciation_manager = CommunicationsManager.get_instance() self._sys_monitor = SystemMonitor() - arb_pol = config.participant["situational_awareness"]["sa_reasoner"]["arbitration_policy"] + arb_pol = config.participant["addons"]["situational_awareness"]["sa_reasoner"]["arbitration_policy"] self._arbitatrion_policy = factory_arbitration_policy(arb_pol, True) self._sa_components: dict[str, SAMComponent] = {} self._sa_discovery: ISADiscovery | None = None - self._verbose = config.participant["situational_awareness"]["sa_reasoner"]["verbose"] + self._verbose = config.participant["addons"]["situational_awareness"]["sa_reasoner"]["verbose"] @property def san(self) -> SANetwork | None: @@ -146,7 +147,8 @@ def is_additional_participant(self): Returns: bool: True if the node is marked as an additional participant, False otherwise. """ - return self._config["mobility_args"]["additional_node"]["status"] + return self._is_additional_node + """ ############################### # REESTRUCTURE TOPOLOGY # diff --git a/nebula/core/situationalawareness/awareness/suggestionbuffer.py b/nebula/core/situationalawareness/awareness/suggestionbuffer.py index 98cae49b2..c1d354133 100644 --- a/nebula/core/situationalawareness/awareness/suggestionbuffer.py +++ b/nebula/core/situationalawareness/awareness/suggestionbuffer.py @@ -5,7 +5,7 @@ from nebula.core.situationalawareness.awareness.sautils.sacommand import SACommand from nebula.core.situationalawareness.awareness.sautils.samoduleagent import SAModuleAgent from nebula.core.utils.locker import Locker -from nebula.utils import logging +import logging class SuggestionBuffer: diff --git a/nebula/core/situationalawareness/situationalawareness.py b/nebula/core/situationalawareness/situationalawareness.py index 6a5dbcbd6..2fee2d7de 100644 --- a/nebula/core/situationalawareness/situationalawareness.py +++ b/nebula/core/situationalawareness/situationalawareness.py @@ -1,6 +1,6 @@ import asyncio from abc import ABC, abstractmethod - +from nebula.core.addonmanager import NebulaAddon from nebula.addons.functions import print_msg_box @@ -156,7 +156,7 @@ def factory_sa_reasoner(sa_reasoner, config) -> ISAReasoner: raise Exception(f"SA Reasoner service {sa_reasoner} not found.") -class SituationalAwareness: +class SituationalAwareness(NebulaAddon): """ High-level coordinator for Situational Awareness in the DFL federation. 
@@ -178,16 +178,16 @@ def __init__(self, config, engine): title="Situational Awareness module", ) self._config = config - selector = self._config.participant["situational_awareness"]["sa_discovery"]["candidate_selector"] + selector = self._config.participant["addons"]["situational_awareness"]["sa_discovery"]["candidate_selector"] selector = selector.lower() - model_handler = config.participant["situational_awareness"]["sa_discovery"]["model_handler"] + model_handler = config.participant["addons"]["situational_awareness"]["sa_discovery"]["model_handler"] self._sad = factory_sa_discovery( "nebula", - self._config.participant["mobility_args"]["additional_node"]["status"], + self._config.participant["deployment_args"]["additional"], selector, model_handler, engine=engine, - verbose=config.participant["situational_awareness"]["sa_discovery"]["verbose"], + verbose=config.participant["addons"]["situational_awareness"]["sa_discovery"]["verbose"], ) self._sareasoner = factory_sa_reasoner( "nebula", @@ -214,7 +214,7 @@ def sar(self): """ return self._sareasoner - async def init(self): + async def start(self): """ Initialize both discovery and reasoner components, linking them together. """ diff --git a/nebula/core/training/lightning.py b/nebula/core/training/lightning.py index d83975147..5c68c91c3 100755 --- a/nebula/core/training/lightning.py +++ b/nebula/core/training/lightning.py @@ -133,7 +133,7 @@ def __init__(self, model, datamodule, config=None): self.round = 0 self.experiment_name = self.config.participant["scenario_args"]["name"] self.idx = self.config.participant["device_args"]["idx"] - self.log_dir = os.path.join(self.config.participant["tracking_args"]["log_dir"], self.experiment_name) + self.log_dir = self.config.participant["tracking_args"]["log_dir"] self._logger = None self.create_logger() enable_deterministic(seed=self.config.participant["scenario_args"]["random_seed"]) @@ -152,9 +152,7 @@ def set_datamodule(self, datamodule): self.datamodule = datamodule def create_logger(self): - if self.config.participant["tracking_args"]["local_tracking"] == "csv": - nebulalogger = CSVLogger(f"{self.log_dir}", name="metrics", version=f"participant_{self.idx}") - elif self.config.participant["tracking_args"]["local_tracking"] == "basic": + if self.config.participant["tracking_args"]["local_tracking"] == "default": logger_config = None if self._logger is not None: logger_config = self._logger.get_logger_config() @@ -167,6 +165,8 @@ def create_logger(self): ) # Restore logger configuration nebulalogger.set_logger_config(logger_config) + elif self.config.participant["tracking_args"]["local_tracking"] == "csv": + nebulalogger = CSVLogger(f"{self.log_dir}", name="metrics", version=f"participant_{self.idx}") else: nebulalogger = None diff --git a/nebula/core/training/scikit.py b/nebula/core/training/scikit.py deleted file mode 100755 index 99cb02b9d..000000000 --- a/nebula/core/training/scikit.py +++ /dev/null @@ -1,79 +0,0 @@ -import logging -import pickle -import traceback - -from sklearn.metrics import accuracy_score - - -class Scikit: - def __init__(self, model, data, config=None, logger=None): - self.model = model - self.data = data - self.config = config - self.logger = logger - self.round = 0 - self.epochs = 1 - self.logger.log_data({"Round": self.round}, step=self.logger.global_step) - - def set_model(self, model): - self.model = model - - def get_round(self): - return self.round - - def set_data(self, data): - self.data = data - - def serialize_model(self, params=None): - if params is None: - 
params = self.model.get_params() - return pickle.dumps(params) - - def deserialize_model(self, data): - try: - params = pickle.loads(data) - return params - except: - raise Exception("Error decoding parameters") - - def set_model_parameters(self, params): - self.model.set_params(**params) - - def get_model_parameters(self): - return self.model.get_params() - - def set_epochs(self, epochs): - self.epochs = epochs - - def fit(self): - try: - X_train, y_train = self.data.train_dataloader() - self.model.fit(X_train, y_train) - except Exception as e: - logging.exception(f"Error with scikit-learn fit. {e}") - logging.exception(traceback.format_exc()) - - def interrupt_fit(self): - pass - - def evaluate(self): - try: - X_test, y_test = self.data.test_dataloader() - y_pred = self.model.predict(X_test) - accuracy = accuracy_score(y_test, y_pred) - logging.info(f"Accuracy: {accuracy}") - except Exception as e: - logging.exception(f"Error with scikit-learn evaluate. {e}") - logging.exception(traceback.format_exc()) - return None - - def get_train_size(self): - return ( - len(self.data.train_dataloader()), - len(self.data.test_dataloader()), - ) - - def finalize_round(self): - self.round += 1 - if self.logger: - self.logger.log_data({"Round": self.round}) diff --git a/nebula/core/training/siamese.py b/nebula/core/training/siamese.py deleted file mode 100755 index 21999c3a5..000000000 --- a/nebula/core/training/siamese.py +++ /dev/null @@ -1,193 +0,0 @@ -import hashlib -import io -import logging -import traceback -from collections import OrderedDict - -import torch -from lightning import Trainer -from lightning.pytorch.callbacks import RichModelSummary, RichProgressBar -from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme - -from nebula.core.utils.deterministic import enable_deterministic - - -class Siamese: - def __init__(self, model, data, config=None, logger=None): - # self.model = torch.compile(model, mode="reduce-overhead") - self.model = model - self.data = data - self.config = config - self.logger = logger - self.__trainer = None - self.epochs = 1 - logging.getLogger("lightning.pytorch").setLevel(logging.INFO) - self.round = 0 - enable_deterministic(seed=self.config.participant["scenario_args"]["random_seed"]) - self.logger.log_data({"Round": self.round}, step=self.logger.global_step) - - @property - def logger(self): - return self._logger - - def get_round(self): - return self.round - - def set_model(self, model): - self.model = model - - def set_data(self, data): - self.data = data - - def create_trainer(self): - logging.info( - "[Trainer] Creating trainer with accelerator: {}".format( - self.config.participant["device_args"]["accelerator"] - ) - ) - progress_bar = RichProgressBar( - theme=RichProgressBarTheme( - description="green_yellow", - progress_bar="green1", - progress_bar_finished="green1", - progress_bar_pulse="#6206E0", - batch_progress="green_yellow", - time="grey82", - processing_speed="grey82", - metrics="grey82", - ), - leave=True, - ) - if self.config.participant["device_args"]["accelerator"] == "gpu": - # NEBULA uses 2 GPUs (max) to distribute the nodes. - if self.config.participant["device_args"]["devices"] > 1: - # If you have more than 2 GPUs, you should specify which ones to use. - gpu_id = ([1] if self.config.participant["device_args"]["idx"] % 2 == 0 else [2],) - else: - # If there is only one GPU, it will be used. 
- gpu_id = [1] - - self.__trainer = Trainer( - callbacks=[RichModelSummary(max_depth=1), progress_bar], - max_epochs=self.epochs, - accelerator=self.config.participant["device_args"]["accelerator"], - devices=gpu_id, - logger=self.logger, - log_every_n_steps=50, - enable_checkpointing=False, - enable_model_summary=False, - enable_progress_bar=True, - # deterministic=True - ) - else: - # NEBULA uses only CPU to distribute the nodes - self.__trainer = Trainer( - callbacks=[RichModelSummary(max_depth=1), progress_bar], - max_epochs=self.epochs, - accelerator=self.config.participant["device_args"]["accelerator"], - devices="auto", - logger=self.logger, - log_every_n_steps=50, - enable_checkpointing=False, - enable_model_summary=False, - enable_progress_bar=True, - # deterministic=True - ) - - def get_global_model_parameters(self): - return self.model.get_global_model_parameters() - - def set_parameter_second_aggregation(self, params): - try: - logging.info("Setting parameters in second aggregation...") - self.model.load_state_dict(params) - except: - raise Exception("Error setting parameters") - - def get_model_parameters(self, bytes=False): - if bytes: - return self.serialize_model(self.model.state_dict()) - else: - return self.model.state_dict() - - def get_hash_model(self): - """ - Returns: - str: SHA256 hash of model parameters - """ - return hashlib.sha256(self.serialize_model()).hexdigest() - - def set_epochs(self, epochs): - self.epochs = epochs - - #### - # Model parameters serialization/deserialization - # From https://pytorch.org/docs/stable/notes/serialization.html - #### - def serialize_model(self, model): - try: - buffer = io.BytesIO() - # with gzip.GzipFile(fileobj=buffer, mode='wb') as f: - # torch.save(params, f) - torch.save(model, buffer) - return buffer.getvalue() - except: - raise Exception("Error serializing model") - - def deserialize_model(self, data): - try: - buffer = io.BytesIO(data) - # with gzip.GzipFile(fileobj=buffer, mode='rb') as f: - # params_dict = torch.load(f, map_location='cpu') - params_dict = torch.load(buffer, map_location="cpu") - return OrderedDict(params_dict) - except: - raise Exception("Error decoding parameters") - - def set_model_parameters(self, params, initialize=False): - try: - if initialize: - self.model.load_state_dict(params) - self.model.global_load_state_dict(params) - self.model.historical_load_state_dict(params) - else: - # First aggregation - self.model.global_load_state_dict(params) - except: - raise Exception("Error setting parameters") - - def train(self): - try: - self.create_trainer() - # torch.autograd.set_detect_anomaly(True) - # TODO: It is necessary to train only the local model, save the history of the previous model and then load it, the global model is the aggregation of all the models. 
- self.__trainer.fit(self.model, self.data) - # Save local model as historical model (previous round) - # It will be compared the next round during training local model (constrantive loss) - # When aggregation in global model (first) and aggregation with similarities and weights (second), the historical model keeps inmutable - logging.info("Saving historical model...") - self.model.save_historical_model() - except Exception as e: - logging.exception(f"Error training model: {e}") - logging.exception(traceback.format_exc()) - - def test(self): - try: - self.create_trainer() - self.__trainer.test(self.model, self.data, verbose=True) - except Exception as e: - logging.exception(f"Error testing model: {e}") - logging.exception(traceback.format_exc()) - - def get_model_weight(self): - return ( - len(self.data.train_dataloader().dataset), - len(self.data.test_dataloader().dataset), - ) - - def finalize_round(self): - self.logger.global_step = self.logger.global_step + self.logger.local_step - self.logger.local_step = 0 - self.round += 1 - self.logger.log_data({"Round": self.round}, step=self.logger.global_step) - pass diff --git a/nebula/database/__init__.py b/nebula/database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nebula/database/adapters/__init__.py b/nebula/database/adapters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nebula/database/adapters/postgress/__init__.py b/nebula/database/adapters/postgress/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nebula/database/adapters/postgress/docker/Dockerfile b/nebula/database/adapters/postgress/docker/Dockerfile new file mode 100644 index 000000000..2859de8da --- /dev/null +++ b/nebula/database/adapters/postgress/docker/Dockerfile @@ -0,0 +1,53 @@ +FROM postgres:17.5-alpine3.22 + +# Rename the official entrypoint so we can wrap it +RUN mv /usr/local/bin/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh.orig + +# Copy SQL init file and custom entrypoint script +COPY ./nebula/database/adapters/postgress/docker/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh +RUN chmod +x /usr/local/bin/docker-entrypoint.sh + +# Install Python 3.11.7 from source +RUN apk add --no-cache \ + gcc \ + g++ \ + musl-dev \ + make \ + openssl-dev \ + bzip2-dev \ + zlib-dev \ + xz-dev \ + readline-dev \ + sqlite-dev \ + libffi-dev \ + curl \ + tar \ + bash + +ENV PYTHON_VERSION=3.11.7 + +RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \ + tar -xzf Python-${PYTHON_VERSION}.tgz && \ + cd Python-${PYTHON_VERSION} && \ + ./configure --prefix=/usr/local --enable-optimizations && \ + make -j$(nproc) && \ + make install && \ + cd .. && rm -rf Python-${PYTHON_VERSION}* + +RUN python3.11 --version + +# Install uv (alternative to pip, very fast) +ADD https://astral.sh/uv/install.sh /uv-installer.sh +RUN sh /uv-installer.sh && rm /uv-installer.sh +ENV PATH="/root/.local/bin/:$PATH" + +COPY pyproject.toml . 
+ +# Install Python dependencies using uv +RUN uv python pin 3.11.7 +RUN uv sync --group database + +ENV PATH="/.venv/bin:$PATH" + +ENTRYPOINT ["/bin/bash", "/usr/local/bin/docker-entrypoint.sh"] +CMD ["postgres"] diff --git a/nebula/database/adapters/postgress/docker/docker-entrypoint.sh b/nebula/database/adapters/postgress/docker/docker-entrypoint.sh new file mode 100644 index 000000000..911961294 --- /dev/null +++ b/nebula/database/adapters/postgress/docker/docker-entrypoint.sh @@ -0,0 +1,21 @@ +#!/bin/sh +set -x + +# Start the python API in the background +echo "🐍 Starting Nebula Database API in the background..." +( + # Wait for postgres to be ready + until pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" -h localhost >/dev/null 2>&1; do + sleep 1 + done + echo "✅ PostgreSQL is ready, starting API." + + cd nebula + NEBULA_SOCK=nebula.sock + + uvicorn nebula.database.database_api:app --host 0.0.0.0 --port 5051 --log-level debug --proxy-headers --forwarded-allow-ips "*" +) & + +# Run the original postgres entrypoint in the foreground +# This will become the main process of the container +exec /usr/local/bin/docker-entrypoint.sh.orig "$@" diff --git a/nebula/database/adapters/postgress/docker/init-configs.sql b/nebula/database/adapters/postgress/docker/init-configs.sql new file mode 100644 index 000000000..7119931ee --- /dev/null +++ b/nebula/database/adapters/postgress/docker/init-configs.sql @@ -0,0 +1,80 @@ +-- -------------------------------------------------- +-- init_postgres.sql +-- -------------------------------------------------- + +-- 1) (Optional) If you need to create the database, uncomment: +-- CREATE DATABASE nebula; +-- \c nebula + +-- 2) Users table +CREATE TABLE IF NOT EXISTS users ( + "user" TEXT PRIMARY KEY, + password TEXT, + role TEXT +); + +-- 2) Nodes +CREATE TABLE IF NOT EXISTS nodes ( + uid TEXT PRIMARY KEY, + idx TEXT, + ip TEXT, + port TEXT, + role TEXT, + neighbors TEXT[], + timestamp TEXT, + federation TEXT, + round TEXT, + scenario TEXT, + hash TEXT, + extras JSONB, + malicious TEXT +); + +-- Ensure column exists for pre-existing installations +ALTER TABLE IF EXISTS nodes + ADD COLUMN IF NOT EXISTS extras JSONB; + +-- Drop legacy columns for latitude/longitude if present +-- ALTER TABLE IF EXISTS nodes +-- DROP COLUMN IF EXISTS latitude; +-- ALTER TABLE IF EXISTS nodes +-- DROP COLUMN IF EXISTS longitude; +-- AlTER TABLE IF EXISTS scenarios +-- ADD COLUMN IF NOT EXISTS federation_id TEXT; +-- ALTER TABLE IF EXISTS scenarios +-- DROP CONSTRAINT scenarios_pkey; +-- ALTER TABLE IF EXISTS scenarios +-- ADD CONSTRAINT scenarios_pkey PRIMARY KEY (federation_id); + +-- 3) Configs as JSONB +DROP INDEX IF EXISTS idx_configs_config_gin; +DROP TABLE IF EXISTS configs; +CREATE TABLE configs ( + id SERIAL PRIMARY KEY, + config JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE INDEX idx_configs_config_gin ON configs USING GIN (config); + +-- 4) Scenarios table as JSONB +CREATE TABLE IF NOT EXISTS scenarios ( + federation_id TEXT PRIMARY KEY, + alias TEXT NOT NULL, + name TEXT NOT NULL, + username TEXT NOT NULL, + status TEXT, + start_time TEXT, + end_time TEXT, + config JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Index for fast JSONB queries on scenarios.config +CREATE INDEX IF NOT EXISTS idx_scenarios_config_gin + ON scenarios USING GIN (config); + +-- 5) Notes table +CREATE TABLE IF NOT EXISTS notes ( + federation_id TEXT PRIMARY KEY, + scenario_notes TEXT +); diff --git a/nebula/database/adapters/postgress/postgress.py 
b/nebula/database/adapters/postgress/postgress.py new file mode 100755 index 000000000..ec6210516 --- /dev/null +++ b/nebula/database/adapters/postgress/postgress.py @@ -0,0 +1,750 @@ +import logging +import os +import datetime +import json +import asyncpg +import asyncio + +from passlib.context import CryptContext + +from nebula.database.database_adapter_interface import DatabaseAdapter + +# --- Configuration --- +# Use environment variables for database credentials from the Docker Compose file +DATABASE_URL = f"postgresql://{os.environ.get('DB_USER')}:{os.environ.get('DB_PASSWORD')}@{os.environ.get('DB_HOST')}:{os.environ.get('DB_PORT')}/nebula" + +# Password hashing context (using Argon2) +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") + +# Asynchronous lock for node updates +_node_lock = asyncio.Lock() + + +class PostgresDB(DatabaseAdapter): + """ + PostgreSQL implementation of the Database interface. + """ + def __init__(self): + self.pool = None + + async def _init_db_pool(self): + """ + Initializes the asynchronous PostgreSQL connection pool. + This should be called once when the application starts. + Retries connection on failure to handle race conditions during startup. + """ + if self.pool is None: + attempts = 10 + for attempt in range(attempts): + try: + self.pool = await asyncpg.create_pool( + dsn=DATABASE_URL, + min_size=5, # Minimum number of connections in the pool + max_size=20, # Maximum number of connections in the pool + ) + logging.info("Database connection pool successfully created.") + return + except (ConnectionRefusedError, asyncpg.exceptions.CannotConnectNowError) as e: + if attempt < attempts - 1: + logging.warning( + f"Database connection failed. Attempt {attempt + 1}/{attempts}. Retrying in 5 seconds... " + f"Error: {e}" + ) + await asyncio.sleep(5) + else: + logging.critical( + f"Failed to create database connection pool after {attempts} attempts: {e}", exc_info=True + ) + raise + except Exception as e: + logging.critical( + f"An unexpected error occurred while creating database connection pool: {e}", exc_info=True + ) + raise + + async def _close_db_pool(self): + """ + Closes the asynchronous PostgreSQL connection pool. + This should be called once when the application shuts down gracefully. + """ + if self.pool: + await self.pool.close() + logging.info("Database connection pool closed.") + + # --- User Management Functions --- + + async def _insert_default_admin(self): + """ + Inserts a default 'ADMIN' user into the database with a hashed password. + The password must be provided via the ADMIN_PASSWORD environment variable. + """ + admin_password = os.environ.get("NEBULA_ADMIN_PASSWORD") + + hashed_password = pwd_context.hash(admin_password) + + query = """ + INSERT INTO users ("user", password, role) + VALUES ($1, $2, $3) + ON CONFLICT ("user") DO NOTHING; + """ + try: + async with self.pool.acquire() as conn: + await conn.execute(query, "ADMIN", hashed_password, "admin") + logging.info("Default admin user inserted (or already exists).") + except Exception as e: + logging.error(f"Failed to insert default admin user: {e}", exc_info=True) + + async def _list_users(self, all_info: bool = False): + """ + Retrieves a list of users from the users database. 
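The adapter keeps a single `asyncpg` pool that must be opened before any query and closed on shutdown. A hedged sketch of that lifecycle run standalone (the real wiring happens in `database_api.py`, shown later in this diff), assuming the `DB_*` and `NEBULA_ADMIN_PASSWORD` environment variables are set as the module expects:

```python
# Illustrative standalone lifecycle for the PostgresDB adapter.
import asyncio

from nebula.database.database_adapter_factory import factory_database_adapter

async def main():
    db = factory_database_adapter("PostgresDB")
    await db._init_db_pool()           # retries while PostgreSQL is still starting up
    try:
        await db._insert_default_admin()
        print(await db._list_users())  # ['ADMIN', ...]
    finally:
        await db._close_db_pool()

asyncio.run(main())
```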
+ """ + async with self.pool.acquire() as conn: + result = await conn.fetch("SELECT * FROM users") + + if all_info: + # Return JSON-serializable dicts with full info + return [dict(row) for row in result] + else: + # Return just the list of usernames (strings) + return [row["user"] for row in result] + + + async def _get_user_info(self, user: str): + """ + Fetches detailed information for a specific user from the users database. + """ + async with self.pool.acquire() as conn: + return await conn.fetchrow('SELECT * FROM users WHERE "user" = $1', user) + + + async def _verify(self, user: str, password: str): + """ + Verifies credentials and returns user info when valid. + + Returns + ------- + dict | None + {"user": USER, "role": ROLE} if valid, otherwise None. + """ + user_up = user.upper() + async with self.pool.acquire() as conn: + row = await conn.fetchrow('SELECT password, role FROM users WHERE "user" = $1', user_up) + if not row: + return None + try: + if pwd_context.verify(password, row["password"]): + return {"user": user_up, "role": row["role"]} + except Exception: + logging.error(f"Error during password verification for user {user_up}", exc_info=True) + return None + + + async def _verify_hash_algorithm(self, user: str): + """ + Checks if the stored password hash for a user uses a supported Argon2 algorithm. + """ + user = user.upper() + argon2_prefixes = ("$argon2i$", "$argon2id$") + async with self.pool.acquire() as conn: + result = await conn.fetchrow('SELECT password FROM users WHERE "user" = $1', user) + if result: + password_hash = result["password"] + return password_hash.startswith(argon2_prefixes) + return False + + + async def _delete_user_from_db(self, user: str): + """ + Deletes a user record from the users database. + """ + async with self.pool.acquire() as conn: + await conn.execute('DELETE FROM users WHERE "user" = $1', user) + + + async def _add_user(self, user:str, password:str, role:str): + """ + Adds a new user to the users database with a hashed password. + """ + hashed_password = pwd_context.hash(password) + async with self.pool.acquire() as conn: + await conn.execute( + 'INSERT INTO users ("user", password, role) VALUES ($1, $2, $3)', + user.upper(), hashed_password, role, + ) + + + async def _update_user(self, user:str, password:str, role:str): + """ + Updates the password and role of an existing user in the users database. + """ + hashed_password = pwd_context.hash(password) + async with self.pool.acquire() as conn: + await conn.execute( + 'UPDATE users SET password = $1, role = $2 WHERE "user" = $3', + hashed_password, role, user.upper(), + ) + + # --- Node Management Functions --- + + async def _list_nodes(self, federation_id:str=None, sort_by:str="idx"): + """ + Retrieves a list of nodes from the nodes database, optionally filtered by scenario and sorted. 
+ """ + # Validate sort_by to prevent SQL injection + allowed_sort_fields = ["uid", "idx", "ip", "port", "role", "timestamp", "federation", "round"] + if sort_by not in allowed_sort_fields: + sort_by = "idx" # Default to a safe field + + try: + async with self.pool.acquire() as conn: + if federation_id: + # Using f-string for column names is generally safe if validated as above + command = f"SELECT * FROM nodes WHERE federation = $1 ORDER BY {sort_by};" + result = await conn.fetch(command, federation_id) + else: + command = f"SELECT * FROM nodes ORDER BY {sort_by};" + result = await conn.fetch(command) + + # Convert to list of dicts and expose latitude/longitude from extras for compatibility + rows = [] + for record in result: + row = dict(record) + extras = row.get("extras") + if isinstance(extras, str): + try: + extras = json.loads(extras) + except json.JSONDecodeError: + extras = None + if isinstance(extras, dict): + if "latitude" in extras and "latitude" not in row: + row["latitude"] = extras.get("latitude") + if "longitude" in extras and "longitude" not in row: + row["longitude"] = extras.get("longitude") + rows.append(row) + return rows + except asyncpg.PostgresError as e: + logging.error(f"Error occurred while listing nodes: {e}") + return None + + + async def _list_nodes_by_federation_id(self, federation_id:str): + """ + Fetches all nodes associated with a specific scenario, ordered by their index as integers. + """ + try: + async with self.pool.acquire() as conn: + command = "SELECT * FROM nodes WHERE federation = $1 ORDER BY CAST(idx AS INTEGER) ASC;" + result = await conn.fetch(command, federation_id) + rows = [] + for record in result: + row = dict(record) + extras = row.get("extras") + if isinstance(extras, str): + try: + extras = json.loads(extras) + except json.JSONDecodeError: + extras = None + if isinstance(extras, dict): + if "latitude" in extras and "latitude" not in row: + row["latitude"] = extras.get("latitude") + if "longitude" in extras and "longitude" not in row: + row["longitude"] = extras.get("longitude") + rows.append(row) + return rows + except Exception as e: + logging.error(f"Error occurred while listing nodes by scenario name: {e}") + return None + + + async def _update_node_record( + self, + node_uid, + idx, + ip, + port, + role, + neighbors, + extras, + timestamp, + federation, + federation_round, + scenario, + run_hash, + malicious, + ): + """ + Inserts or updates a node record in the database for a given scenario, ensuring thread-safe access. 
+ """ + async with _node_lock: + async with self.pool.acquire() as conn: + try: + # Ensure `extras` is a JSON string when provided + extras_payload = None + if extras is not None: + if isinstance(extras, str): + extras_payload = extras + else: + try: + extras_payload = json.dumps(extras) + except (TypeError, ValueError): + # Fallback to empty JSON object on serialization issues + logging.warning("Unable to serialize extras to JSON, storing as empty object.") + extras_payload = json.dumps({}) + + # Ensure malicious is stored as text if the column expects text + malicious_payload = malicious if isinstance(malicious, str) else str(malicious) + + async with conn.transaction(): + result = await conn.fetchrow( + "SELECT * FROM nodes WHERE uid = $1 AND scenario = $2 FOR UPDATE;", + node_uid, scenario + ) + + if result is None: + # Insert new node + await conn.execute( + """ + INSERT INTO nodes (uid, idx, ip, port, role, neighbors, + timestamp, federation, round, scenario, hash, extras, malicious) + VALUES ($1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12::jsonb, $13); + """, + node_uid, idx, ip, port, role, neighbors, + timestamp, federation, federation_round, scenario, run_hash, extras_payload, malicious_payload, + ) + else: + # Update existing node + await conn.execute( + """ + UPDATE nodes SET idx = $1, ip = $2, port = $3, role = $4, neighbors = $5, + timestamp = $6, federation = $7, round = $8, + hash = $9, extras = $10::jsonb, malicious = $11 + WHERE uid = $12 AND scenario = $13; + """, + idx, ip, port, role, neighbors, + timestamp, federation, federation_round, + run_hash, extras_payload, malicious_payload, + node_uid, scenario, + ) + + updated_row = await conn.fetchrow("SELECT * from nodes WHERE uid = $1 AND scenario = $2;", node_uid, scenario) + return dict(updated_row) if updated_row else None + except asyncpg.PostgresError as e: + logging.error(f"Database error during node record update: {e}", exc_info=True) + return None + + + async def _remove_all_nodes(self): + """ + Deletes all node records from the nodes database. + """ + async with self.pool.acquire() as conn: + await conn.execute("TRUNCATE nodes CASCADE;") # Use CASCADE if there are foreign key dependencies + + + async def _remove_nodes_by_federation_id(self, federation_id:str): + """ + Deletes all nodes associated with a specific scenario from the database. + """ + async with self.pool.acquire() as conn: + await conn.execute("DELETE FROM nodes WHERE federation = $1;", federation_id) + + # --- Scenario Management Functions --- + + async def _get_all_scenarios(self, username:str, role:str, sort_by:str="start_time"): + """ + Retrieves all scenarios from the database, accessing fields from the 'config' (JSONB) column + and direct columns. Filters by user role and sorts by the specified field. 
+ """ + allowed_sort_fields = ["start_time", "title", "username", "status", "name"] + if sort_by not in allowed_sort_fields: + sort_by = "start_time" + + # Determine the ORDER BY clause based on sort_by + if sort_by == "start_time": + order_by_clause = """ + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + to_timestamp(start_time, 'DD/MM/YYYY HH24:MI:SS') DESC + """ + elif sort_by in ["title", "model", "dataset", "rounds"]: # These are inside config JSONB + order_by_clause = f"ORDER BY config->>'{sort_by}'" + else: # For direct table columns like name, username, status + order_by_clause = f"ORDER BY {sort_by}" + + async with self.pool.acquire() as conn: + # Select direct columns and relevant fields from config JSONB + command = """ + SELECT + federation_id, + name, + username, + status, + start_time, + end_time, + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- return the full config JSONB + FROM scenarios + """ + params = [] + + if role != "admin": + command += " WHERE username = $1" # username is a direct column now + params.append(username) + + full_command = f"{command} {order_by_clause};" + return await conn.fetch(full_command, *params) + + + async def _get_all_scenarios_and_check_completed(self, user:str, role:str, sort_by:str="start_time"): + """ + Retrieves all scenarios, sorts them, and updates the status if necessary. + Returns a list of dictionaries, where each dictionary represents a scenario. + """ + # Safe list of allowed sorting fields to prevent SQL injection. + allowed_sort_fields = ["start_time", "title", "username", "status", "name"] + if sort_by not in allowed_sort_fields: + sort_by = "start_time" # Safe default value + + # Building the ORDER BY clause + if sort_by == "start_time": + order_by_clause = """ + ORDER BY + CASE + WHEN start_time IS NULL OR start_time = '' THEN 1 + ELSE 0 + END, + to_timestamp(start_time, 'DD/MM/YYYY HH24:MI:SS') DESC + """ + elif sort_by in ["title", "model", "dataset", "rounds"]: # These are inside config JSONB + order_by_clause = f"ORDER BY config->>'{sort_by}'" + else: # For direct table columns like name, username, status + order_by_clause = f"ORDER BY {sort_by}" + + async with self.pool.acquire() as conn: + # Base query that extracts fields from the JSONB using the ->> operator + command = f""" + SELECT + federation_id, + name, + username, + status, + start_time, + end_time, + config->>'title' AS title, + config->>'model' AS model, + config->>'dataset' AS dataset, + config->>'rounds' AS rounds, + config -- Return the full config object + FROM scenarios + """ + params = [] + if role != "admin": + command += " WHERE username = $1" # username is a direct column + params.append(user) + + command += f" {order_by_clause};" + + result_dicts = await conn.fetch(command, *params) + + scenarios_to_return = [dict(s) for s in result_dicts] + + re_fetch_required = False + for scenario in scenarios_to_return: + if scenario["status"] == "running": + if await self._check_scenario_federation_completed(scenario["federation_id"]): + await self._scenario_set_status_to_completed(scenario["federation_id"]) + re_fetch_required = True + break + + if re_fetch_required: + # Recursively call to get fresh data after status update + return await self._get_all_scenarios_and_check_completed(user, role, sort_by) + + return scenarios_to_return + + + async def _scenario_update_record(self, federation_id:str, alias:str, scenario_name:str, 
start_time:datetime, end_time:datetime, scenario:dict, status:str, username:str): + """ + Inserts or updates a scenario record using the PostgreSQL "UPSERT" pattern. + All configuration is saved in the 'config' column of type JSONB. + Direct columns (name, start_time, end_time, username, status) are also handled. + """ + # Ensure scenario is a dictionary before dumping to JSON + if not isinstance(scenario, dict): + try: + scenario = json.loads(scenario) + except (json.JSONDecodeError, TypeError): + logging.error("scenario is not a valid JSON string or dict.") + return + + command = """ + INSERT INTO scenarios (federation_id, alias, name, start_time, end_time, username, status, config) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb) + ON CONFLICT (federation_id) DO UPDATE SET + alias = EXCLUDED.alias, + name = EXCLUDED.name, + start_time = EXCLUDED.start_time, + end_time = EXCLUDED.end_time, + username = EXCLUDED.username, + status = EXCLUDED.status, + config = scenarios.config || EXCLUDED.config; -- Merge JSONB + """ + async with self.pool.acquire() as conn: + await conn.execute(command, federation_id, alias, scenario_name, start_time, end_time, username, status, json.dumps(scenario)) + + + async def _scenario_set_all_status_to_finished(self): + """ + Sets the status of all 'running' scenarios to 'finished' + and updates their 'end_time' (both in the direct column and within JSONB). + """ + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set(config, '{status}', '"finished"') || + jsonb_set(config, '{end_time}', $2::jsonb) + WHERE status = 'running'; + """ + async with self.pool.acquire() as conn: + await conn.execute(command, current_time, json.dumps(current_time)) + + + async def _scenario_set_status_to_finished(self, federation_id:str): + """ + Sets the status of a specific scenario to 'finished' and updates its 'end_time'. + Updates both the direct columns and the JSONB 'config'. + """ + current_time = datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') # Consistent format + command = """ + UPDATE scenarios + SET + status = 'finished', + end_time = $1, + config = jsonb_set( + jsonb_set(config, '{status}', '"finished"'), + '{end_time}', $2::jsonb + ) + WHERE federation_id = $3; + """ + async with self.pool.acquire() as conn: + await conn.execute(command, current_time, json.dumps(current_time), federation_id) + + + async def _scenario_set_status_to_completed(self, federation_id:str): + """ + Sets the status of a specific scenario to 'completed'. + Updates both the direct column and the JSONB 'config'. + """ + command = """ + UPDATE scenarios + SET + status = 'completed', + config = jsonb_set(config, '{status}', '"completed"') + WHERE federation_id = $1; + """ + async with self.pool.acquire() as conn: + await conn.execute(command, federation_id) + + + async def _finish_scenario(self, federation_id: str, all: bool = False): + """ + Consolidated method to set scenarios to finished. + """ + if all: + await self._scenario_set_all_status_to_finished() + else: + await self._scenario_set_status_to_finished(federation_id) + + + async def _get_running_scenario(self, username:str=None, get_all:bool=False): + """ + Retrieves scenarios with a 'running' status, optionally filtered by user. + Returns full scenario record (including direct columns and config JSONB). 
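The scenario upsert above merges the incoming JSON with whatever is already stored (`scenarios.config || EXCLUDED.config`), so re-submitting a scenario only overwrites the keys it carries. A hedged call sketch with placeholder values:

```python
# Illustrative scenario upsert; placeholder values throughout.
async def scenario_upsert_demo(db):
    await db._scenario_update_record(
        federation_id="fed-123",
        alias="mnist-demo",
        scenario_name="nebula_mnist_run",
        start_time="01/01/2025 12:00:00",
        end_time="",
        scenario={"scenario_title": "MNIST demo", "model": "MLP", "dataset": "MNIST", "rounds": "10"},
        status="running",
        username="alice",
    )
    # A later call with scenario={"rounds": "20"} keeps title/model/dataset and only updates "rounds".
```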
+ """ + async with self.pool.acquire() as conn: + params = ["running"] + # Select all columns to get both direct and config data + command = "SELECT federation_id, name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1" + + if username: + command += " AND username = $2" + params.append(username) + + if get_all: + result = [dict(row) for row in await conn.fetch(command, *params)] # Convert records to dicts + else: + result_row = await conn.fetchrow(command, *params) + result = dict(result_row) if result_row else None + return result + + + async def _get_completed_scenario(self): + """ + Retrieves a single scenario with a 'completed' status. + Returns full scenario record (including direct columns and config JSONB). + """ + async with self.pool.acquire() as conn: + command = "SELECT name, username, status, start_time, end_time, config FROM scenarios WHERE status = $1;" + result_row = await conn.fetchrow(command, "completed") + return dict(result_row) if result_row else None + + async def _get_scenarios(self, user: str, role: str): + """ + Compose scenarios list and running scenario respecting role. + """ + scenarios = await self._get_all_scenarios_and_check_completed(user=user, role=role) + scenario_running = await self._get_running_scenario(None if role == "admin" else user) + return {"scenarios": scenarios, "scenario_running": scenario_running} + + + async def _get_scenario_by_federation_id(self, federation_id:str): + """ + Retrieves the complete record of a scenario by its name. + """ + async with self.pool.acquire() as conn: + result_row = await conn.fetchrow("SELECT name, start_time, end_time, username, status, config FROM scenarios WHERE federation_id = $1;", federation_id) + + result = dict(result_row) if result_row else None + + if result and result.get('config'): + # Assuming 'config' is a JSON string from the DB, so we parse it + # It might already be a dict if asyncpg handles JSONB conversion automatically + config_data = result['config'] + if isinstance(config_data, str): + try: + config_data = json.loads(config_data) + except json.JSONDecodeError: + config_data = {} + + # Extract the 'scenario_title' and add it as a top-level key + result['title'] = config_data.get('scenario_title') + result['description'] = config_data.get('description') + + return result + + + async def _get_user_by_federation_id(self, federation_id:str): + """ + Retrieves the username associated with a scenario (from the direct 'username' column). + """ + async with self.pool.acquire() as conn: + return await conn.fetchval("SELECT username FROM scenarios WHERE federation_id = $1;", federation_id) + + + async def _remove_scenario_by_federation_id(self, federation_id:str): + """ + Delete a scenario from the database by its unique name. + """ + try: + async with self.pool.acquire() as conn: + await conn.execute("DELETE FROM scenarios WHERE federation_id = $1;", federation_id) + logging.info(f"Scenario '{federation_id}' successfully removed.") + except asyncpg.PostgresError as e: + logging.error(f"Error occurred while deleting scenario '{federation_id}': {e}") + + + async def _check_scenario_federation_completed(self, federation_id:str): + """ + Check if all nodes in a given scenario have completed the required federation rounds. 
+ """ + try: + async with self.pool.acquire() as conn: + # Retrieve the total rounds for the scenario from the 'config' JSONB column + scenario_rounds_str = await conn.fetchval("SELECT config->>'rounds' AS rounds FROM scenarios WHERE federation_id = $1;", federation_id) + + if not scenario_rounds_str: + logging.warning(f"Scenario '{federation_id}' not found or 'rounds' not defined.") + return False + + # Ensure total_rounds is an integer for comparison + try: + total_rounds = int(scenario_rounds_str) + except (ValueError, TypeError): + logging.error(f"Invalid 'rounds' value for scenario '{federation_id}': {scenario_rounds_str}") + return False + + # Fetch the current round progress of all nodes in that scenario + nodes = await conn.fetch("SELECT round FROM nodes WHERE federation = $1;", federation_id) + + if not nodes: + logging.info(f"No nodes found for federation '{federation_id}'. Federation not considered completed.") + return False + + # Check if all nodes have completed the total rounds + return all(int(node["round"]) >= total_rounds for node in nodes) + + except asyncpg.PostgresError as e: + logging.error(f"PostgreSQL error during check_scenario_federation_completed for '{federation_id}': {e}") + return False + except ValueError as e: + logging.error(f"Data error during check_scenario_federation_completed for '{federation_id}': {e}") + return False + + + async def _check_scenario_with_role(self, role:str, federation_id:str, user:str=None): + """ + Verify if a scenario exists that the user with the given role and username can access. + """ + scenario_info = await self._get_scenario_by_federation_id(federation_id) + + if not scenario_info: + return False # Scenario does not exist + + if role == "admin": + return True # Admins can access any existing scenario + + if user is None: + logging.warning( + "check_scenario_with_role called for non-admin role without user." + ) + return False + + return scenario_info.get("username") == user + + # --- Notes Management Functions --- + + async def _save_notes(self, federation_id: str, notes: str): + """ + Save or update notes associated with a specific scenario. + """ + try: + async with self.pool.acquire() as conn: + await conn.execute( + """ + INSERT INTO notes (federation_id, scenario_notes) VALUES ($1, $2) + ON CONFLICT(federation_id) DO UPDATE SET scenario_notes = EXCLUDED.scenario_notes; + """, + federation_id, notes, + ) + except asyncpg.PostgresError as e: + logging.error(f"PostgreSQL error during save_notes: {e}") + + + async def _get_notes(self, federation_id: str): + """ + Retrieve notes associated with a specific scenario. + """ + async with self.pool.acquire() as conn: + row = await conn.fetchrow("SELECT * FROM notes WHERE federation_id = $1;", federation_id) + if row is None: + # No notes stored for this scenario yet + return None + return dict(row) + + + async def _remove_note(self, federation_id: str): + """ + Delete the note associated with a specific scenario. 
+ """ + async with self.pool.acquire() as conn: + await conn.execute("DELETE FROM notes WHERE federation_id = $1;", federation_id) diff --git a/nebula/database/database_adapter_factory.py b/nebula/database/database_adapter_factory.py new file mode 100644 index 000000000..ddbe443c1 --- /dev/null +++ b/nebula/database/database_adapter_factory.py @@ -0,0 +1,17 @@ +from nebula.database.adapters.postgress.postgress import PostgresDB +from nebula.database.database_adapter_interface import DatabaseAdapter + +class DatabaseAdapterException(Exception): + pass + +def factory_database_adapter(database_adapter: str) -> DatabaseAdapter: + + ADAPTERS = { + "PostgresDB": PostgresDB + } + + db_adapter = ADAPTERS.get(database_adapter, None) + if db_adapter: + return db_adapter() + else: + raise DatabaseAdapterException(f"Database Adapter \"{database_adapter}\" not supported") diff --git a/nebula/database/database_adapter_interface.py b/nebula/database/database_adapter_interface.py new file mode 100644 index 000000000..49ff82d7b --- /dev/null +++ b/nebula/database/database_adapter_interface.py @@ -0,0 +1,198 @@ +from abc import ABC, abstractmethod + + +class DatabaseAdapter(ABC): + """ + Abstract base class for database operations. + Defines a common interface for interacting with different database systems. + """ + + @abstractmethod + async def _init_db_pool(self): + """Initializes the database connection pool.""" + raise NotImplementedError + + @abstractmethod + async def _close_db_pool(self): + """Closes the database connection pool.""" + raise NotImplementedError + + # --- User Management Functions --- + + @abstractmethod + async def _insert_default_admin(self): + """Inserts a default admin user.""" + raise NotImplementedError + + @abstractmethod + async def _list_users(self, all_info=False): + """Retrieves a list of users.""" + raise NotImplementedError + + @abstractmethod + async def _get_user_info(self, user): + """Fetches detailed information for a specific user.""" + raise NotImplementedError + + @abstractmethod + async def _verify(self, user, password): + """Verifies user credentials.""" + raise NotImplementedError + + @abstractmethod + async def _verify_hash_algorithm(self, user): + """Checks the password hash algorithm for a user.""" + raise NotImplementedError + + @abstractmethod + async def _delete_user_from_db(self, user): + """Deletes a user from the database.""" + raise NotImplementedError + + @abstractmethod + async def _add_user(self, user, password, role): + """Adds a new user.""" + raise NotImplementedError + + @abstractmethod + async def _update_user(self, user, password, role): + """Updates an existing user.""" + raise NotImplementedError + + # --- Node Management Functions --- + + @abstractmethod + async def _list_nodes(self, federation_id=None, sort_by="idx"): + """Retrieves a list of nodes.""" + raise NotImplementedError + + @abstractmethod + async def _list_nodes_by_federation_id(self, federation_id): + """Fetches all nodes for a specific federation.""" + raise NotImplementedError + + @abstractmethod + async def _update_node_record( + self, + node_uid, + idx, + ip, + port, + role, + neighbors, + extras, + timestamp, + federation, + federation_round, + scenario, + run_hash, + malicious, + ): + """Inserts or updates a node record. 
Latitude/longitude must be included in `extras` (JSON).""" + raise NotImplementedError + + @abstractmethod + async def _remove_all_nodes(self): + """Deletes all node records.""" + raise NotImplementedError + + @abstractmethod + async def _remove_nodes_by_federation_id(self, federation_id): + """Deletes all nodes for a specific federation.""" + raise NotImplementedError + + # --- Scenario Management Functions --- + + @abstractmethod + async def _get_all_scenarios(self, username, role, sort_by="start_time"): + """Retrieves all scenarios.""" + raise NotImplementedError + + @abstractmethod + async def _get_all_scenarios_and_check_completed(self, username, role, sort_by="start_time"): + """Retrieves all scenarios and checks for completion.""" + raise NotImplementedError + + @abstractmethod + async def _scenario_update_record(self, federation_id, name, start_time, end_time, scenario_config, status, username): + """Inserts or updates a scenario record.""" + raise NotImplementedError + + @abstractmethod + async def _scenario_set_all_status_to_finished(self): + """Sets the status of all running scenarios to 'finished'.""" + raise NotImplementedError + + @abstractmethod + async def _scenario_set_status_to_finished(self, federation_id): + """Sets the status of a specific scenario (by federation_id) to 'finished'.""" + raise NotImplementedError + + @abstractmethod + async def _scenario_set_status_to_completed(self, federation_id): + """Sets the status of a specific scenario (by federation_id) to 'completed'.""" + raise NotImplementedError + + @abstractmethod + async def _get_running_scenario(self, username=None, get_all=False): + """Retrieves running scenarios.""" + raise NotImplementedError + + @abstractmethod + async def _get_completed_scenario(self): + """Retrieves a completed scenario.""" + raise NotImplementedError + + @abstractmethod + async def _get_scenario_by_federation_id(self, federation_id): + """Retrieves a scenario by its federation_id.""" + raise NotImplementedError + + @abstractmethod + async def _get_user_by_federation_id(self, federation_id): + """Retrieves the user associated with a scenario by federation_id.""" + raise NotImplementedError + + @abstractmethod + async def _remove_scenario_by_federation_id(self, federation_id): + """Deletes a scenario by its federation_id.""" + raise NotImplementedError + + @abstractmethod + async def _check_scenario_federation_completed(self, federation_id): + """Checks if a scenario's federation is complete.""" + raise NotImplementedError + + @abstractmethod + async def _check_scenario_with_role(self, role, federation_id, user=None): + """Verifies if a user can access a scenario by federation_id.""" + raise NotImplementedError + + # --- Notes Management Functions --- + + @abstractmethod + async def _save_notes(self, scenario, notes): + """Saves or updates notes for a scenario.""" + raise NotImplementedError + + @abstractmethod + async def _get_notes(self, scenario): + """Retrieves notes for a scenario.""" + raise NotImplementedError + + @abstractmethod + async def _remove_note(self, scenario): + """Deletes the note for a scenario.""" + raise NotImplementedError + + # --- Scenario Finish (no API logic) --- + + @abstractmethod + async def _finish_scenario(self, federation_id, all: bool = False): + """Sets status to finished for one scenario (by federation_id) or all running scenarios.""" + raise NotImplementedError + + @abstractmethod + async def _get_scenarios(self, user: str, role: str): + """Return scenarios list and running scenario, given user 
and role.""" + raise NotImplementedError diff --git a/nebula/database/database_api.py b/nebula/database/database_api.py new file mode 100644 index 000000000..bd74a7df7 --- /dev/null +++ b/nebula/database/database_api.py @@ -0,0 +1,347 @@ + +import logging +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +from fastapi import FastAPI, HTTPException, status, Depends +from fastapi.concurrency import asynccontextmanager + +from nebula.database.database_adapter_factory import factory_database_adapter +from nebula.database.utils_requests import ( + Routes, + ScenarioUpdateRequest, + ScenarioStopRequest, + ScenarioFinishRequest, + NotesUpdateRequest, + UserAddRequest, + UserDeleteRequest, + UserUpdateRequest, + UserVerifyRequest, + NodesUpdateRequest, + GetScenariosRequest, + GetRunningScenarioRequest, + CheckScenarioRequest, + ListUsersRequest, +) + +# Get a database instance +db = factory_database_adapter("PostgresDB") + + +# Setup logger +def configure_logger(log_file): + """ + Configures the logging system for the database API. + """ + log_console_format = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s" + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(logging.Formatter(log_console_format)) + file_handler = logging.FileHandler(log_file, mode="w") + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")) + logging.basicConfig( + level=logging.DEBUG, + handlers=[ + console_handler, + file_handler, + ], + ) + uvicorn_loggers = ["uvicorn", "uvicorn.error", "uvicorn.access"] + for logger_name in uvicorn_loggers: + logger = logging.getLogger(logger_name) + logger.handlers = [] + logger.propagate = False + handler = logging.FileHandler(log_file, mode="a") + handler.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")) + logger.addHandler(handler) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Application lifespan context manager for the database API. 
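+
+ On startup it configures logging, opens the database connection pool and
+ inserts the default admin user; on shutdown it closes the pool.
+
+ Serving sketch (module path and port are illustrative, not fixed by this file):
+ uvicorn nebula.database.database_api:app --host 0.0.0.0 --port 5051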
+ """ + # Code to run on startup + db_log = os.environ.get("NEBULA_DATABASE_LOG", "database.log") + configure_logger(db_log) + + # Initialize the database connection pool + await db._init_db_pool() + await db._insert_default_admin() + + yield + + # Code to run on shutdown + await db._close_db_pool() + + +app = FastAPI(lifespan=lifespan) + + +@app.get(Routes.INIT) +async def read_root(): + return {"message": "Welcome to the NEBULA Database API"} + + +# Scenarios +@app.post(Routes.UPDATE) +async def update_scenario( + federation_id: str, + request: ScenarioUpdateRequest, +): + try: + await db._scenario_update_record( + federation_id = federation_id, + **request.model_dump() + ) + return {"message": f"Scenario {request.scenario_name} updated successfully"} + except Exception as e: + logging.exception( + f"Error updating scenario {request.scenario_name}: {e}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.STOP) +async def stop_scenario( + federation_id: str, + request: ScenarioStopRequest, +): + try: + await db._finish_scenario(federation_id, request.all) + return {"message": "Finished status set successfully"} + except Exception as e: + logging.exception( + f"Error stopping scenario {federation_id}: {e}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.REMOVE) +async def remove_scenario( + federation_id: str +): + try: + await db._remove_scenario_by_federation_id(federation_id) + return {"message": f"Scenario {federation_id} removed successfully"} + except Exception as e: + logging.exception( + f"Error removing scenario {federation_id}: {e}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get(Routes.GET_SCENARIOS_BY_USER) +async def get_scenarios( + request: GetScenariosRequest = Depends() +): + try: + return await db._get_scenarios(request.user, request.role) + except Exception as e: + logging.exception(f"Error obtaining scenarios: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.FINISH) +async def set_scenario_status_to_finished( + federation_id: str, + request: ScenarioFinishRequest, +): + try: + await db._finish_scenario( + federation_id, request.all + ) + return {"message": "Finished status set successfully"} + except Exception as e: + logging.exception( + f"Error setting scenario {federation_id} to finished: {e}" + ) + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get(Routes.RUNNING) +async def get_running_scenario_endpoint(request: GetRunningScenarioRequest = Depends()): + try: + return await db._get_running_scenario(get_all=request.get_all) + except Exception as e: + logging.exception(f"Error obtaining running scenario: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get(Routes.CHECK_SCENARIO) +async def check_scenario( + request: CheckScenarioRequest = Depends() +): + try: + allowed = await db._check_scenario_with_role(**request.model_dump()) + return {"allowed": allowed} + except Exception as e: + logging.exception(f"Error checking scenario with role: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.get(Routes.GET_SCENARIOS_BY_SCENARIO_NAME) +async def get_scenario_by_name_endpoint( + federation_id: str +): + try: + scenario = await db._get_scenario_by_federation_id(federation_id) + return scenario + except Exception as e: + logging.exception(f"Error obtaining scenario {federation_id}: {e}") + 
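# The full traceback is logged above; clients only receive a generic 500.
+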
raise HTTPException(status_code=500, detail="Internal server error") + + +# Nodes +@app.get(Routes.NODES_BY_FEDERATION_ID) +async def list_nodes_by_federation_id_endpoint( + federation_id: str +): + try: + nodes = await db._list_nodes_by_federation_id(federation_id) + return nodes + except Exception as e: + logging.exception(f"Error obtaining nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.NODES_UPDATE) +async def update_node_record(request: NodesUpdateRequest): + try: + # Build extras from mobility_args + extras = { + "latitude": request.mobility_args.latitude, + "longitude": request.mobility_args.longitude, + } + await db._update_node_record( + str(request.device_args.uid), + str(request.device_args.idx), + str(request.network_args.ip), + str(request.network_args.port), + str(request.device_args.role), + request.network_args.neighbors, + extras, + str(request.timestamp), + str(request.scenario_args.federation), + str(request.federation_args.round), + str(request.scenario_args.name), + str(request.tracking_args.run_hash), + bool(request.device_args.malicious), + ) + return {"message": "Node updated successfully"} + except Exception as e: + logging.exception(f"Error updating node: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.NODES_REMOVE) +async def remove_nodes_by_federation_id_endpoint(federation_id: str): + try: + await db._remove_nodes_by_federation_id(federation_id) + return {"message": f"Nodes for federation {federation_id} removed successfully"} + except Exception as e: + logging.exception(f"Error removing nodes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +# Notes +@app.get(Routes.NOTES_BY_FEDERATION_ID) +async def get_notes_by_federation_id( + federation_id: str +): + try: + notes_record = await db._get_notes(federation_id) + return notes_record + except Exception as e: + logging.exception(f"Error obtaining notes for federation {federation_id}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.NOTES_UPDATE) +async def update_notes_by_scenario_name(federation_id: str, request: NotesUpdateRequest): + try: + await db._save_notes(federation_id ,**request.model_dump()) + return {"message": f"Notes for federation {federation_id} updated successfully"} + except Exception as e: + logging.exception(f"Error updating notes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.NOTES_REMOVE) +async def remove_notes_by_federation_id_endpoint(federation_id: str): + try: + await db._remove_note(federation_id) + return {"message": f"Notes for federation {federation_id} removed successfully"} + except Exception as e: + logging.exception(f"Error removing notes: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +# Users +@app.get(Routes.USER_LIST) +async def list_users_controller(request: ListUsersRequest = Depends()): + try: + return {"users": await db._list_users(request.all_info)} + except Exception as e: + logging.exception(f"Error retrieving users: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error retrieving users: {e}") + + +@app.get(Routes.USER_BY_FEDERATION_ID) +async def get_user_by_federation_id_endpoint( + federation_id: str +): + try: + user = await db._get_user_by_federation_id(federation_id) + return user + except Exception as e: + logging.exception(f"Error obtaining user 
for federation {federation_id}: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + +@app.post(Routes.USER_ADD) +async def add_user_controller(request: UserAddRequest): + try: + await db._add_user(**request.model_dump()) + return {"detail": "User added successfully"} + except Exception as e: + logging.exception(f"Error adding user: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error adding user: {e}") + + +@app.post(Routes.USER_DELETE) +async def remove_user_controller(request: UserDeleteRequest): + try: + await db._delete_user_from_db(request.user) + return {"detail": "User deleted successfully"} + except Exception as e: + logging.exception(f"Error deleting user: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error deleting user: {e}") + + +@app.post(Routes.USER_UPDATE) +async def update_user_controller(request: UserUpdateRequest): + try: + await db._update_user(**request.model_dump()) + return {"detail": "User updated successfully"} + except Exception as e: + logging.exception(f"Error updating user: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error updating user: {e}") + + +@app.post(Routes.USER_VERIFY) +async def verify_user_controller(request: UserVerifyRequest): + try: + auth = await db._verify(**request.model_dump()) + if auth: + return auth + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + except HTTPException as e: + # Propagate intended HTTP errors (e.g., 401) without wrapping + raise e + except Exception as e: + logging.exception(f"Error verifying user: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error verifying user: {e}") diff --git a/nebula/database/pgweb/Dockerfile b/nebula/database/pgweb/Dockerfile new file mode 100644 index 000000000..c9bb132e0 --- /dev/null +++ b/nebula/database/pgweb/Dockerfile @@ -0,0 +1 @@ +FROM sosedoff/pgweb diff --git a/nebula/database/utils_requests.py b/nebula/database/utils_requests.py new file mode 100644 index 000000000..59dde8244 --- /dev/null +++ b/nebula/database/utils_requests.py @@ -0,0 +1,192 @@ +from typing import Any, Dict, List + +from pydantic import BaseModel, confloat, conint + + +class Routes: + # Scenarios + INIT = "/" + UPDATE = "/scenarios/{federation_id}/update" + STOP = "/scenarios/{federation_id}/stop" + REMOVE = "/scenarios/{federation_id}/remove" + FINISH = "/scenarios/{federation_id}/set_status_to_finished" + RUNNING = "/scenarios/running" + CHECK_SCENARIO = "/scenarios/check/{user}/{role}/{federation_id}" + GET_SCENARIOS_BY_USER = "/scenarios/{user}/{role}" + GET_SCENARIOS_BY_SCENARIO_NAME = "/scenarios/{federation_id}" + + # Nodes + NODES_BY_FEDERATION_ID = "/nodes/{federation_id}" + NODES_UPDATE = "/nodes/update" + NODES_REMOVE = "/nodes/{federation_id}/remove" + + # Notes + NOTES_BY_FEDERATION_ID = "/notes/{federation_id}" + NOTES_UPDATE = "/notes/{federation_id}/update" + NOTES_REMOVE = "/notes/{federation_id}/remove" + + # Users + USER_LIST = "/user/list" + USER_BY_FEDERATION_ID = "/user/{federation_id}" + USER_ADD = "/user/add" + USER_DELETE = "/user/delete" + USER_UPDATE = "/user/update" + USER_VERIFY = "/user/verify" + + +class ScenarioUpdateRequest(BaseModel): + alias: str + scenario_name: str + start_time: str + end_time: str + scenario: Dict[str, Any] + status: str + username: str + + +class ScenarioStopRequest(BaseModel): + all: bool = False + + +class ScenarioFinishRequest(BaseModel): + all: bool = 
False + + +class NotesUpdateRequest(BaseModel): + notes: str + + + + + +class UserAddRequest(BaseModel): + user: str + password: str + role: str + + +class UserDeleteRequest(BaseModel): + user: str + + +class UserUpdateRequest(BaseModel): + user: str + password: str + role: str + + +class UserVerifyRequest(BaseModel): + user: str + password: str + + +# Nodes update payload +class DeviceArgs(BaseModel): + uid: str + idx: int + role: str + malicious: bool + + +class NetworkArgs(BaseModel): + ip: str + port: conint(ge=1, le=65535) # type: ignore[valid-type] + neighbors: List[Any] + + +class MobilityArgs(BaseModel): + latitude: confloat(ge=-90, le=90) # type: ignore[valid-type] + longitude: confloat(ge=-180, le=180) # type: ignore[valid-type] + + +class TrackingArgs(BaseModel): + run_hash: str + + +class FederationArgs(BaseModel): + round: int + + +class ScenarioArgs(BaseModel): + federation: str + name: str + + +class NodesUpdateRequest(BaseModel): + device_args: DeviceArgs + network_args: NetworkArgs + mobility_args: MobilityArgs + tracking_args: TrackingArgs + federation_args: FederationArgs + scenario_args: ScenarioArgs + timestamp: str + +class GetScenariosRequest(BaseModel): + user: str + role: str + + +class GetRunningScenarioRequest(BaseModel): + get_all: bool = False + + +class CheckScenarioRequest(BaseModel): + user: str + role: str + federation_id: str + + +class ListUsersRequest(BaseModel): + all_info: bool = False + + + + +def factory_requests_path(resource: str, user: str = "", role: str = "", federation_id: str = "") -> str: + if resource == "init": + return Routes.INIT + elif resource == "update": + return Routes.UPDATE + elif resource == "stop": + return Routes.STOP + elif resource == "remove": + return Routes.REMOVE + elif resource == "finish": + return Routes.FINISH + elif resource == "running": + return Routes.RUNNING + elif resource == "check_scenario": + return Routes.CHECK_SCENARIO.format(user=user, role=role, federation_id=federation_id) + elif resource == "get_scenarios_by_user": + return Routes.GET_SCENARIOS_BY_USER.format(user=user, role=role) + elif resource == "get_scenarios_by_scenario_name": + return Routes.GET_SCENARIOS_BY_SCENARIO_NAME.format(federation_id=federation_id) + # Nodes + elif resource == "get_nodes_by_scenario_name": + return Routes.NODES_BY_FEDERATION_ID.format(federation_id=federation_id) + elif resource == "update_nodes": + return Routes.NODES_UPDATE + elif resource == "remove_nodes": + return Routes.NODES_REMOVE.format(federation_id=federation_id) + # Notes + elif resource == "get_notes_by_scenario_name": + return Routes.NOTES_BY_FEDERATION_ID.format(federation_id=federation_id) + elif resource == "update_notes": + return Routes.NOTES_UPDATE + elif resource == "remove_notes": + return Routes.NOTES_REMOVE.format(federation_id=federation_id) + # Users + elif resource == "list_users": + return Routes.USER_LIST + elif resource == "get_user_by_scenario_name": + return Routes.USER_BY_FEDERATION_ID.format(federation_id=federation_id) + elif resource == "add_user": + return Routes.USER_ADD + elif resource == "delete_user": + return Routes.USER_DELETE + elif resource == "update_user": + return Routes.USER_UPDATE + elif resource == "verify_user": + return Routes.USER_VERIFY + else: + raise Exception(f"resource not found: {resource}") diff --git a/nebula/frontend/app.py b/nebula/frontend/app.py index 8cb38d9cb..59c836a1b 100755 --- a/nebula/frontend/app.py +++ b/nebula/frontend/app.py @@ -30,8 +30,9 @@ class Settings: controller_port (int): Port on which 
the Nebula controller listens (default: 5050). resources_threshold (float): Threshold for resource usage alerts (default: 0.0). port (int): Port for the Nebula frontend service (default: 6060). - production (bool): Whether the application is running in production mode. - advanced_analytics (bool): Whether advanced analytics features are enabled. + env_tag (str): Tag for the environment (e.g., 'dev', 'prod'). + prefix_tag (str): Tag for the deployment prefix (e.g., 'dev', 'prod'). + user_tag (str): Tag for the user (e.g., 'admin', 'user'). host_platform (str): Underlying host operating platform (e.g., 'unix'). log_dir (str): Directory path where application logs are stored. config_dir (str): Directory path for general configuration files. @@ -49,8 +50,9 @@ class Settings: controller_port: int = os.environ.get("NEBULA_CONTROLLER_PORT", 5050) resources_threshold: float = 0.0 port: int = os.environ.get("NEBULA_FRONTEND_PORT", 6060) - production: bool = os.environ.get("NEBULA_PRODUCTION", "False") == "True" - advanced_analytics: bool = os.environ.get("NEBULA_ADVANCED_ANALYTICS", "False") == "True" + env_tag: str = os.environ.get("NEBULA_ENV_TAG", "dev") + prefix_tag: str = os.environ.get("NEBULA_PREFIX_TAG", "dev") + user_tag: str = os.environ.get("NEBULA_USER_TAG", os.environ.get("USER", "unknown")) host_platform: str = os.environ.get("NEBULA_HOST_PLATFORM", "unix") log_dir: str = os.environ.get("NEBULA_LOGS_DIR") config_dir: str = os.environ.get("NEBULA_CONFIG_DIR") @@ -116,16 +118,8 @@ class Settings: logging.info(f"🚀 Starting Nebula Frontend on port {settings.port}") -logging.info(f"NEBULA_PRODUCTION: {settings.production}") - -if "SECRET_KEY" not in os.environ: - logging.info("Generating SECRET_KEY") - os.environ["SECRET_KEY"] = os.urandom(24).hex() - logging.info(f"Saving SECRET_KEY to {settings.env_file}") - with open(settings.env_file, "a") as f: - f.write(f"SECRET_KEY={os.environ['SECRET_KEY']}\n") -else: - logging.info("SECRET_KEY already set") +logging.info(f"NEBULA_PRODUCTION: {settings.env_tag == 'prod'}") +logging.info(f"NEBULA_DEPLOYMENT_PREFIX: {settings.prefix_tag}") app = FastAPI() app.add_middleware( @@ -289,9 +283,11 @@ def add_global_context(request: Request): Returns: dict[str, bool]: is_production: Flag indicating if the application is running in production mode. + prefix: The prefix of the application. """ return { - "is_production": settings.production, + "is_production": settings.env_tag == "prod", + "prefix": settings.prefix_tag, } @@ -631,7 +627,7 @@ async def get_scenarios(user, role): return await controller_get(url) -async def scenario_update_record(scenario_name, start_time, end_time, scenario, status, role, username): +async def scenario_update_record(scenario_name, start_time, end_time, scenario, status, username): """ Update the record of a scenario's execution status on the controller. @@ -641,7 +637,6 @@ async def scenario_update_record(scenario_name, start_time, end_time, scenario, end_time (str): ISO-formatted end timestamp. scenario (Any): Scenario payload or identifier. status (str): New status value (e.g., 'running', 'finished'). - role (str): Role associated with the scenario. username (str): User who ran or updated the scenario. 
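
+ Example (illustrative call with the updated signature; variable names are placeholders):
+ await scenario_update_record(name, start, end, cfg, "running", session["user"])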
Raises: @@ -654,7 +649,6 @@ async def scenario_update_record(scenario_name, start_time, end_time, scenario, "end_time": end_time, "scenario": scenario, "status": status, - "role": role, "username": username, } await controller_post(url, data) @@ -691,7 +685,7 @@ async def remove_scenario_by_name(scenario_name): await controller_post(url, data) -async def check_scenario_with_role(role, scenario_name): +async def check_scenario_with_role(role, scenario_name, user): """ Check if a specific scenario is allowed for the session's role. @@ -705,7 +699,7 @@ async def check_scenario_with_role(role, scenario_name): Raises: HTTPException: If the underlying HTTP GET request fails. """ - url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/check/{role}/{scenario_name}" + url = f"http://{settings.controller_host}:{settings.controller_port}/scenarios/check/{user}/{role}/{scenario_name}" check_data = await controller_get(url) return check_data.get("allowed", False) @@ -1599,6 +1593,7 @@ async def nebula_dashboard_monitor(scenario_name: str, request: Request, session # Calculate initial status based on timestamp timestamp = datetime.datetime.strptime(node["timestamp"], "%Y-%m-%d %H:%M:%S.%f") is_online = (datetime.datetime.now() - timestamp) <= datetime.timedelta(seconds=25) + mobility_args = json.loads(node["extras"]) formatted_nodes.append({ "uid": node["uid"], @@ -1606,9 +1601,9 @@ async def nebula_dashboard_monitor(scenario_name: str, request: Request, session "ip": node["ip"], "port": node["port"], "role": node["role"], - "neighbors": node["neighbors"], - "latitude": node["latitude"], - "longitude": node["longitude"], + "neighbors": " ".join(node["neighbors"]), + "latitude": mobility_args["latitude"], + "longitude": mobility_args["longitude"], "timestamp": node["timestamp"], "federation": node["federation"], "round": str(node["round"]), @@ -1710,8 +1705,7 @@ async def nebula_update_node(scenario_name: str, request: Request): "ip": config["network_args"]["ip"], "port": str(config["network_args"]["port"]), "role": config["device_args"]["role"], - "malicious": config["device_args"]["malicious"], - "neighbors": config["network_args"]["neighbors"], + "neighbors": " ".join(config["network_args"]["neighbors"]), "latitude": config["mobility_args"]["latitude"], "longitude": config["mobility_args"]["longitude"], "timestamp": config["timestamp"], @@ -1856,8 +1850,6 @@ async def remove_scenario(scenario_name=None, user=None): user_data = user_data_store[user] - if settings.advanced_analytics: - logging.info("Advanced analytics enabled") # Remove registered nodes and conditions user_data.nodes_registration.pop(scenario_name, None) await remove_nodes_by_scenario_name(scenario_name) @@ -1890,7 +1882,7 @@ async def nebula_relaunch_scenario( if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not await check_scenario_with_role(session["role"], scenario_name): + if not await check_scenario_with_role(session["role"], scenario_name, session["user"]): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) scenario_path = FileUtils.check_path(settings.config_dir, os.path.join(scenario_name, "scenario.json")) @@ -1931,7 +1923,7 @@ async def nebula_remove_scenario(scenario_name: str, session: dict = Depends(get if session["role"] == "demo": raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) elif session["role"] == "user": - if not await check_scenario_with_role(session["role"], scenario_name): + if not 
await check_scenario_with_role(session["role"], scenario_name, session["user"]): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) await remove_scenario(scenario_name, session["user"]) return RedirectResponse(url="/platform/dashboard") @@ -1939,14 +1931,6 @@ async def nebula_remove_scenario(scenario_name: str, session: dict = Depends(get raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) -if settings.advanced_analytics: - logging.info("Advanced analytics enabled") -else: - logging.info("Advanced analytics disabled") - - # TENSORBOARD START - - @app.get("/platform/dashboard/statistics/", response_class=HTMLResponse) @app.get("/platform/dashboard/{scenario_name}/statistics/", response_class=HTMLResponse) async def nebula_dashboard_statistics(request: Request, scenario_name: str = None): @@ -2168,7 +2152,8 @@ async def assign_available_gpu(scenario_data, role): running_gpus = [] # Obtain associated gpus of the running scenarios for scenario in running_scenarios: - scenario_gpus = json.loads(scenario["gpu_id"]) + config = json.loads(scenario["config"]) + scenario_gpus = config.get("gpu_id", []) # Obtain the list of gpus in use without duplicates for gpu in scenario_gpus: if gpu not in running_gpus: diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index ca1d1cfd6..0bca1feb5 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -104,7 +104,7 @@ }, "tracking_args": { "enable_remote_tracking": false, - "local_tracking": "basic", + "local_tracking": "default", "log_dir": "/Users/enrique/Documents/nebula/app/logs", "config_dir": "/Users/enrique/Documents/nebula/app/config", "run_hash": "" diff --git a/nebula/frontend/start_services.sh b/nebula/frontend/start_services.sh index b2d26e8a2..facac33b1 100755 --- a/nebula/frontend/start_services.sh +++ b/nebula/frontend/start_services.sh @@ -27,11 +27,7 @@ else uvicorn app:app --uds /tmp/$NEBULA_SOCK --log-level info --proxy-headers --forwarded-allow-ips "*" & fi -if [ "$NEBULA_ADVANCED_ANALYTICS" = "False" ]; then - echo "Starting Tensorboard analytics" - tensorboard --host 0.0.0.0 --port 8080 --logdir $NEBULA_LOGS_DIR --window_title "NEBULA Statistics" --reload_interval 30 --max_reload_threads 10 --reload_multifile true & -else - echo "Advanced analytics are enabled" -fi +echo "Starting Tensorboard analytics" +tensorboard --host 0.0.0.0 --port 8080 --logdir $NEBULA_LOGS_DIR --window_title "NEBULA Statistics" --reload_interval 30 --max_reload_threads 10 --reload_multifile true & tail -f /dev/null diff --git a/nebula/frontend/static/css/style.css b/nebula/frontend/static/css/style.css index 262732d28..0d2644142 100755 --- a/nebula/frontend/static/css/style.css +++ b/nebula/frontend/static/css/style.css @@ -633,8 +633,8 @@ hr.styled { .container { max-width: 100%; - padding-right: 10px; - padding-left: 10px; + padding-right: 25px; + padding-left: 25px; } } diff --git a/nebula/physical/api.py b/nebula/physical/api.py index 6d44428e7..3e7da1bec 100644 --- a/nebula/physical/api.py +++ b/nebula/physical/api.py @@ -350,7 +350,7 @@ def run(): if TRAINING_PROC and TRAINING_PROC.poll() is None: _json_abort(409, "Training already running") - cmd = ["python", "/home/dietpi/prueba/nebula/nebula/node.py", json_files[0]] + cmd = ["python", "/home/dietpi/test/nebula/nebula/node.py", json_files[0]] TRAINING_PROC = subprocess.Popen(cmd) return jsonify(pid=TRAINING_PROC.pid, state="running") @@ -379,8 +379,8 @@ def 
setup_new_run(): Expected multipart-form fields ------------------------------- - * **config** – JSON with scenario, network and security arguments - * **global_test** – shared evaluation dataset (`*.h5`) + * **config** – JSON with scenario, network and security arguments + * **global_test** – shared evaluation dataset (`*.h5`) * **train_set** – participant-specific training dataset (`*.h5`) The function rewrites paths inside *config*, validates neighbour IPs @@ -489,4 +489,4 @@ def setup_new_run(): # ----------------------------------------------------------------------------- if __name__ == "__main__": # Local testing: python main.py - app.run(host="0.0.0.0", port=8000, debug=False) \ No newline at end of file + app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/nebula/utils.py b/nebula/utils.py index cfc7a558b..11f127c8b 100644 --- a/nebula/utils.py +++ b/nebula/utils.py @@ -2,8 +2,24 @@ import os import socket +import aiohttp +import aiohttp import docker +import re +from typing import Optional + +from fastapi import HTTPException +from aiohttp import ClientConnectorError +from aiohttp.client_exceptions import ClientError +import asyncio +import re +from typing import Optional + +from fastapi import HTTPException +from aiohttp import ClientConnectorError +from aiohttp.client_exceptions import ClientError +import asyncio class FileUtils: """ @@ -33,6 +49,27 @@ def check_path(cls, base_path, relative_path): raise Exception("Not allowed") return full_path + @classmethod + def update_env_file(cls, env_file, key, value): + """ + Update or add a key-value pair in the .env file. + """ + import re + lines = [] + if os.path.exists(env_file): + with open(env_file, "r") as f: + lines = f.readlines() + key_found = False + for i, line in enumerate(lines): + if re.match(rf"^{key}=.*", line): + lines[i] = f"{key}={value}\n" + key_found = True + break + if not key_found: + lines.append(f"{key}={value}\n") + with open(env_file, "w") as f: + f.writelines(lines) + class SocketUtils: """ @@ -75,7 +112,6 @@ def find_free_port(cls, start_port=49152, end_port=65535): return port return None - class DockerUtils: """ Utility class for Docker operations such as creating networks, @@ -174,7 +210,7 @@ def check_docker_by_prefix(cls, prefix): for container in containers: if container.name.startswith(prefix): return True - + return False except docker.errors.APIError: @@ -182,92 +218,149 @@ def check_docker_by_prefix(cls, prefix): except Exception: logging.exception("Unexpected error") - @classmethod - def remove_docker_network(cls, network_name): + +class LoggerUtils: + + @staticmethod + def configure_logger( + name: Optional[str] = None, + log_file: Optional[str] = None, + level: int = logging.INFO, + console: bool = True, + strip_ansi: bool = True, + file_mode: str = "w", + log_format: str = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s", + date_format: str = "%Y-%m-%d %H:%M:%S", + ) -> logging.Logger: """ - Removes a Docker network by name. + Configure and return a logger with optional console and file output. Args: - network_name (str): Name of the Docker network to remove. + name (str): Logger name. If None, the root logger is used. + log_file (str): Path to the log file. + level (int): Logging level (DEBUG, INFO, etc). + console (bool): If True, output is also printed to the console. + strip_ansi (bool): Placeholder for future ANSI stripping support. + file_mode (str): File mode for the log file ('a' for append, 'w' for overwrite). + log_format (str): Format for log messages. 
+ date_format (str): Format for timestamps. Returns: - None + logging.Logger: Configured logger instance. """ - try: - # Connect to Docker - client = docker.from_env() + logger = logging.getLogger(name) + logger.setLevel(level) - # Get the network by name - network = client.networks.get(network_name) + # Prevent duplicate handler setup + if getattr(logger, "_is_configured", False): + return logger - # Remove the network - network.remove() + formatter = logging.Formatter(fmt=log_format, datefmt=date_format) - logging.info(f"Network {network_name} removed successfully.") - except docker.errors.NotFound: - logging.exception(f"Network {network_name} not found.") - except docker.errors.APIError: - logging.exception("Error interacting with Docker") - except Exception: - logging.exception("Unexpected error") + if log_file: + os.makedirs(os.path.dirname(log_file), exist_ok=True) + fh = logging.FileHandler(log_file, mode=file_mode) + fh.setLevel(level) + fh.setFormatter(formatter) + logger.addHandler(fh) - @classmethod - def remove_docker_networks_by_prefix(cls, prefix): + if console: + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + # Mark this logger as configured to avoid re-adding handlers + logger._is_configured = True + logger.propagate = False + + return logger + +class APIUtils(): + + @staticmethod + async def retry_with_backoff(func, *args, max_retries=5, initial_delay=1): """ - Removes all Docker networks whose names start with the given prefix. + Retry a function with exponential backoff. Args: - prefix (str): Prefix string to match network names. + func: The async function to retry + *args: Arguments to pass to the function + max_retries: Maximum number of retry attempts + initial_delay: Initial delay between retries in seconds Returns: - None + The result of the function if successful + + Raises: + The last exception if all retries fail """ - try: - # Connect to Docker - client = docker.from_env() + delay = initial_delay + last_exception = None + + for attempt in range(max_retries): + try: + return await func(*args) + except (ClientConnectorError, ClientError) as e: + last_exception = e + if attempt < max_retries - 1: + logging.warning(f"Connection attempt {attempt + 1} failed: {str(e)}. Retrying in {delay} seconds...") + await asyncio.sleep(delay) + delay *= 2 # Exponential backoff + else: + logging.error(f"All {max_retries} connection attempts failed") + raise last_exception - # List all networks - networks = client.networks.list() + @staticmethod + async def get(url, params=None): + """ + Fetch JSON data from a remote controller endpoint via asynchronous HTTP GET. - # Filter and remove networks with names starting with the prefix - for network in networks: - if network.name.startswith(prefix): - network.remove() - logging.info(f"Network {network.name} removed successfully.") + Parameters: + url (str): The full URL of the controller API endpoint. + params (dict, optional): A dictionary of query parameters to be sent with the request. - except docker.errors.NotFound: - logging.info(f"One or more networks with prefix {prefix} not found.") - except docker.errors.APIError: - logging.info("Error interacting with Docker") - except Exception: - logging.info("Unexpected error") + Returns: + Any: Parsed JSON response when the HTTP status code is 200. - @classmethod - def remove_containers_by_prefix(cls, prefix): + Raises: + HTTPException: If the response status is not 200, raises with the response status code and an error detail. 
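+
+ Example (illustrative; host, port and federation id are placeholders):
+ nodes = await APIUtils.get(f"http://{db_host}:{db_port}/nodes/{fed_id}")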
""" - Removes all Docker containers whose names start with the given prefix. - Containers are forcibly removed even if they are running. - Args: - prefix (str): Prefix string to match container names. + async def _get(): + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params) as response: + if response.status == 200: + return await response.json() + else: + detail = await response.text() + raise HTTPException(status_code=response.status, detail=detail) - Returns: - None + return await APIUtils.retry_with_backoff(_get) + + @staticmethod + async def post(url, data=None): """ - try: - # Connect to Docker client - client = docker.from_env() + Asynchronously send a JSON payload via HTTP POST to a controller endpoint and parse the response. - containers = client.containers.list(all=True) # `all=True` to include stopped containers + Parameters: + url (str): The full URL of the controller API endpoint. + data (Any, optional): JSON-serializable payload to include in the POST request (default: None). - # Iterate through containers and remove those with the matching prefix - for container in containers: - if container.name.startswith(prefix): - logging.info(f"Removing container: {container.name}") - container.remove(force=True) # force=True to stop and remove if running - logging.info(f"Container {container.name} removed successfully.") + Returns: + Any: Parsed JSON response when the HTTP status code is 200. - except docker.errors.APIError: - logging.exception("Error interacting with Docker") - except Exception: - logging.exception("Unexpected error") + Raises: + HTTPException: If the response status is not 200, with the status code and an error detail. + """ + + async def _post(): + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as response: + if response.status == 200: + return await response.json() + else: + detail = await response.text() + raise HTTPException(status_code=response.status, detail=detail) + + return await APIUtils.retry_with_backoff(_post) diff --git a/pyproject.toml b/pyproject.toml index d54ff3dcd..d1bf059c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,6 @@ docs = [ controller = [ "aiohttp==3.10.5", "aiosqlite==0.20.0", - "argon2-cffi==23.1.0", "docker==7.1.0", "fastapi[all]==0.114.0", "gunicorn==23.0.0", @@ -80,6 +79,13 @@ controller = [ "scikit-image==0.24.0", "scikit-learn==1.5.1", ] +database = [ + "argon2-cffi==23.1.0", + "asyncpg==0.30.0", + "psycopg2-binary==2.9.10", + "passlib==1.7.4", + "fastapi[all]==0.114.0", +] core = [ "aiohttp==3.10.5", "async-timeout==4.0.3",