diff --git a/device-discovery/device_discovery/client.py b/device-discovery/device_discovery/client.py index af416a81..feac83d6 100644 --- a/device-discovery/device_discovery/client.py +++ b/device-discovery/device_discovery/client.py @@ -4,8 +4,9 @@ import logging import threading +from typing import Any -from netboxlabs.diode.sdk import DiodeClient, DiodeDryRunClient +from netboxlabs.diode.sdk import DiodeClient, DiodeDryRunClient, DiodeOTLPClient from device_discovery.translate import translate_data from device_discovery.version import version_semver @@ -82,7 +83,7 @@ def init_client( app_name=f"{prefix}/{APP_NAME}" if prefix else APP_NAME, output_dir=dry_run_output_dir, ) - else: + elif client_id is not None and client_secret is not None: self.diode_client = DiodeClient( target=target, app_name=f"{prefix}/{APP_NAME}" if prefix else APP_NAME, @@ -90,14 +91,21 @@ def init_client( client_id=client_id, client_secret=client_secret, ) + else: + logger.debug("Initializing Diode OTLP client") + self.diode_client = DiodeOTLPClient( + target=target, + app_name=f"{prefix}/{APP_NAME}" if prefix else APP_NAME, + app_version=APP_VERSION, + ) - def ingest(self, hostname: str, data: dict): + def ingest(self, metadata: dict[str, Any] | None, data: dict): """ Ingest data using the Diode client after translating it. Args: ---- - hostname (str): The device hostname. + metadata (dict[str, Any] | None): Metadata to attach to the ingestion request. data (dict): The data to be ingested. Raises: @@ -109,9 +117,17 @@ def ingest(self, hostname: str, data: dict): raise ValueError("Diode client not initialized") with self._lock: - response = self.diode_client.ingest(translate_data(data)) + translated_entities = translate_data(data) + request_metadata = metadata or {} + response = self.diode_client.ingest( + entities=translated_entities, metadata=request_metadata + ) + + hostname = request_metadata.get("hostname") or "unknown-host" if response.errors: - logger.error(f"ERROR ingestion failed for {hostname} : {response.errors}") + logger.error( + f"ERROR ingestion failed for {hostname} : {response.errors}" + ) else: logger.info(f"Hostname {hostname}: Successful ingestion") diff --git a/device-discovery/device_discovery/main.py b/device-discovery/device_discovery/main.py index c59a25f3..794f22ac 100644 --- a/device-discovery/device_discovery/main.py +++ b/device-discovery/device_discovery/main.py @@ -136,8 +136,6 @@ def main(): name for name, val in [ ("--diode-target", args.diode_target), - ("--diode-client-id", args.diode_client_id), - ("--diode-client-secret", args.diode_client_secret), ] if not val ] diff --git a/device-discovery/device_discovery/policy/runner.py b/device-discovery/device_discovery/policy/runner.py index e9bdcae1..b3bf9be2 100644 --- a/device-discovery/device_discovery/policy/runner.py +++ b/device-discovery/device_discovery/policy/runner.py @@ -187,7 +187,8 @@ def _collect_device_data( logger.error( f"Policy {self.name}, Hostname {sanitized_hostname}: Error getting VLANs: {e}" ) - Client().ingest(scope.hostname, data) + metadata = {"policy_name": self.name, "hostname": sanitized_hostname} + Client().ingest(metadata, data) discovery_success = get_metric("discovery_success") if discovery_success: discovery_success.add(1, {"policy": self.name}) diff --git a/device-discovery/pyproject.toml b/device-discovery/pyproject.toml index dc3a9929..39388129 100644 --- a/device-discovery/pyproject.toml +++ b/device-discovery/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "fastapi~=0.115", "httpx~=0.27", "napalm~=5.0", - "netboxlabs-diode-sdk~=1.1", + "netboxlabs-diode-sdk~=1.6", "pydantic~=2.9", "python-dotenv~=1.0", "uvicorn~=0.32", diff --git a/device-discovery/tests/policy/test_runner.py b/device-discovery/tests/policy/test_runner.py index 32641199..b3d2d18e 100644 --- a/device-discovery/tests/policy/test_runner.py +++ b/device-discovery/tests/policy/test_runner.py @@ -155,7 +155,11 @@ def test_run_device_with_discovered_driver(policy_runner, sample_scopes, sample_ # Verify driver discovery and ingestion mock_discover.assert_called_once_with(sample_scopes[0]) mock_ingest.assert_called_once() - data = mock_ingest.call_args[0][1] + metadata_arg, data = mock_ingest.call_args[0] + assert metadata_arg == { + "policy_name": policy_runner.name, + "hostname": sample_scopes[0].hostname, + } assert data["driver"] == "ios" assert data["device"] == {"model": "SampleModel"} assert data["interface"] == {"eth0": "up"} diff --git a/device-discovery/tests/test_client.py b/device-discovery/tests/test_client.py index ded814ac..f9e66265 100644 --- a/device-discovery/tests/test_client.py +++ b/device-discovery/tests/test_client.py @@ -59,6 +59,12 @@ def sample_data(): } +@pytest.fixture +def sample_metadata(): + """Sample metadata for testing ingestion.""" + return {"policy_name": "test-policy", "hostname": "router1"} + + @pytest.fixture def mock_version_semver(): """Mock the version_semver function.""" @@ -73,6 +79,13 @@ def mock_diode_client_class(): yield mock +@pytest.fixture +def mock_diode_otlp_client_class(): + """Mock the DiodeOTLPClient class.""" + with patch("device_discovery.client.DiodeOTLPClient") as mock: + yield mock + + def test_init_client(mock_diode_client_class, mock_version_semver): """Test the initialization of the Diode client.""" client = Client() @@ -92,7 +105,7 @@ def test_init_client(mock_diode_client_class, mock_version_semver): ) -def test_ingest_success(mock_diode_client_class, sample_data): +def test_ingest_success(mock_diode_client_class, sample_data, sample_metadata): """Test successful data ingestion.""" client = Client() client.init_client( @@ -101,18 +114,20 @@ def test_ingest_success(mock_diode_client_class, sample_data): mock_diode_instance = mock_diode_client_class.return_value mock_diode_instance.ingest.return_value.errors = [] - hostname = sample_data["device"]["hostname"] - + metadata = sample_metadata with patch( "device_discovery.client.translate_data", return_value=translate_data(sample_data), ) as mock_translate_data: - client.ingest(hostname, sample_data) + client.ingest(metadata, sample_data) mock_translate_data.assert_called_once_with(sample_data) - mock_diode_instance.ingest.assert_called_once() + mock_diode_instance.ingest.assert_called_once_with( + entities=mock_translate_data.return_value, + metadata=metadata, + ) -def test_ingest_failure(mock_diode_client_class, sample_data): +def test_ingest_failure(mock_diode_client_class, sample_data, sample_metadata): """Test data ingestion with errors.""" client = Client() client.init_client( @@ -124,25 +139,27 @@ def test_ingest_failure(mock_diode_client_class, sample_data): mock_diode_instance = mock_diode_client_class.return_value mock_diode_instance.ingest.return_value.errors = ["Error1", "Error2"] - hostname = sample_data["device"]["hostname"] - + metadata = sample_metadata with patch( "device_discovery.client.translate_data", return_value=translate_data(sample_data), ) as mock_translate_data: - client.ingest(hostname, sample_data) + client.ingest(metadata, sample_data) mock_translate_data.assert_called_once_with(sample_data) - mock_diode_instance.ingest.assert_called_once() + mock_diode_instance.ingest.assert_called_once_with( + entities=mock_translate_data.return_value, + metadata=metadata, + ) assert len(mock_diode_instance.ingest.return_value.errors) > 0 -def test_ingest_without_initialization(): +def test_ingest_without_initialization(sample_metadata): """Test ingestion without client initialization raises ValueError.""" Client._instance = None # Reset the Client singleton instance client = Client() with pytest.raises(ValueError, match="Diode client not initialized"): - client.ingest("", {}) + client.ingest(sample_metadata, {}) def test_client_dry_run(tmp_path, sample_data): @@ -154,7 +171,8 @@ def test_client_dry_run(tmp_path, sample_data): dry_run_output_dir=tmp_path, ) hostname = sample_data["device"]["hostname"] - client.ingest(hostname, sample_data) + metadata = {"policy_name": "dry-run-policy", "hostname": hostname} + client.ingest(metadata, sample_data) files = list(tmp_path.glob("prefix_device-discovery*.json")) assert len(files) == 1 @@ -174,8 +192,24 @@ def test_client_dry_run_stdout(capsys, sample_data): ) hostname = sample_data["device"]["hostname"] - client.ingest(hostname, sample_data) + metadata = {"policy_name": "dry-run-policy", "hostname": hostname} + client.ingest(metadata, sample_data) captured = capsys.readouterr() assert sample_data["device"]["hostname"] in captured.out assert sample_data["interface"]["GigabitEthernet0/0"]["mac_address"] in captured.out + + +def test_init_client_uses_otlp_when_credentials_missing( + mock_diode_client_class, mock_diode_otlp_client_class, mock_version_semver +): + """Ensure init_client falls back to DiodeOTLPClient when credentials are not provided.""" + client = Client() + client.init_client(prefix="prefix", target="https://example.com") + + assert not mock_diode_client_class.called + mock_diode_otlp_client_class.assert_called_once_with( + target="https://example.com", + app_name="prefix/device-discovery", + app_version=mock_version_semver(), + ) diff --git a/worker/pyproject.toml b/worker/pyproject.toml index 3eded638..1e16a330 100644 --- a/worker/pyproject.toml +++ b/worker/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "croniter~=5.0", "fastapi~=0.115", "httpx~=0.27", - "netboxlabs-diode-sdk~=1.1", + "netboxlabs-diode-sdk~=1.6", "pydantic~=2.9", "uvicorn~=0.32", "PyYAML~=6.0", diff --git a/worker/tests/policy/test_runner.py b/worker/tests/policy/test_runner.py index 244d61e8..947b4845 100644 --- a/worker/tests/policy/test_runner.py +++ b/worker/tests/policy/test_runner.py @@ -16,7 +16,11 @@ @pytest.fixture def policy_runner(): """Fixture to create a PolicyRunner instance.""" - return PolicyRunner() + runner = PolicyRunner() + runner.metadata = Metadata( + name="test_backend", app_name="test_app", app_version="1.0" + ) + return runner @pytest.fixture @@ -71,6 +75,15 @@ def mock_diode_client(): mock_diode_client.return_value = mock_instance yield mock_diode_client + +@pytest.fixture +def mock_diode_otlp_client(): + """Fixture to mock the DiodeOTLPClient constructor.""" + with patch("worker.policy.runner.DiodeOTLPClient") as mock_diode_otlp_client: + mock_instance = MagicMock() + mock_diode_otlp_client.return_value = mock_instance + yield mock_diode_otlp_client + @pytest.fixture def mock_diode_dry_run_client(): """Fixture to mock the DiodeDryRunClient constructor.""" @@ -138,6 +151,28 @@ def test_setup_policy_runner_with_one_time_run( assert mock_start.called assert policy_runner.status == Status.RUNNING + +def test_setup_policy_runner_uses_otlp_client( + policy_runner, + sample_policy, + mock_load_class, + mock_diode_client, + mock_diode_otlp_client, +): + """Ensure setup falls back to DiodeOTLPClient when credentials are missing.""" + otlp_config = DiodeConfig(target="http://localhost:8080", prefix="test-prefix") + with patch.object(policy_runner.scheduler, "start") as mock_start, patch.object( + policy_runner.scheduler, "add_job" + ) as mock_add_job: + policy_runner.setup("policy1", otlp_config, sample_policy) + + mock_start.assert_called_once() + mock_add_job.assert_called_once() + + mock_load_class.assert_called_once() + assert not mock_diode_client.called + mock_diode_otlp_client.assert_called_once() + def test_setup_policy_runner_dry_run( policy_runner, sample_diode_dry_run_config, @@ -185,6 +220,29 @@ def test_run_success(policy_runner, sample_policy, mock_diode_client, mock_backe assert len(call_args) == 3 +def test_run_passes_metadata_to_ingest( + policy_runner, sample_policy, mock_diode_client, mock_backend +): + """Ensure run forwards policy/backend metadata to the Diode client.""" + policy_runner.name = "policy-meta" + policy_runner.metadata = Metadata( + name="custom_backend", app_name="custom", app_version="0.1" + ) + + entity = ingester_pb2.Entity() + entity.device.name = "device-1" + mock_backend.run.return_value = [entity] + mock_diode_client.ingest.return_value.errors = [] + + policy_runner.run(mock_diode_client, mock_backend, sample_policy) + + _, kwargs = mock_diode_client.ingest.call_args + assert kwargs["metadata"] == { + "policy_name": "policy-meta", + "worker_backend": "custom_backend", + } + + def test_run_ingestion_errors( policy_runner, sample_policy, diff --git a/worker/tests/test_main.py b/worker/tests/test_main.py index 119f4a2c..4a7cc334 100644 --- a/worker/tests/test_main.py +++ b/worker/tests/test_main.py @@ -7,7 +7,7 @@ import pytest -from worker.main import main +from worker.main import main, resolve_env_var @pytest.fixture @@ -169,3 +169,16 @@ def test_main_missing_policy(mock_parse_args): main() except Exception as e: assert str(e) == "Test Exit" + + +def test_resolve_env_var_expands_environment(monkeypatch): + """Ensure resolve_env_var expands placeholders using environment variables.""" + monkeypatch.setenv("MY_ENDPOINT", "grpc://localhost:4317") + assert resolve_env_var("${MY_ENDPOINT}") == "grpc://localhost:4317" + + +def test_resolve_env_var_returns_original(monkeypatch): + """Ensure resolve_env_var returns original string when expansion is not possible.""" + monkeypatch.delenv("NOT_DEFINED", raising=False) + assert resolve_env_var("plain-value") == "plain-value" + assert resolve_env_var("${NOT_DEFINED}") == "${NOT_DEFINED}" diff --git a/worker/worker/main.py b/worker/worker/main.py index 4a991ace..6220ec45 100644 --- a/worker/worker/main.py +++ b/worker/worker/main.py @@ -136,8 +136,6 @@ def main(): name for name, val in [ ("--diode-target", args.diode_target), - ("--diode-client-id", args.diode_client_id), - ("--diode-client-secret", args.diode_client_secret), ] if not val ] @@ -174,7 +172,7 @@ def main(): ) try: - if not config.dry_run: + if not config.dry_run and client_id is not None and client_secret is not None: DiodeClient( target=config.target, app_name="validate", diff --git a/worker/worker/policy/runner.py b/worker/worker/policy/runner.py index 2c9d34db..3872eeff 100644 --- a/worker/worker/policy/runner.py +++ b/worker/worker/policy/runner.py @@ -9,7 +9,7 @@ from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.date import DateTrigger -from netboxlabs.diode.sdk import DiodeClient, DiodeDryRunClient +from netboxlabs.diode.sdk import DiodeClient, DiodeDryRunClient, DiodeOTLPClient from netboxlabs.diode.sdk.diode.v1 import ingester_pb2 from worker.backend import Backend, load_class @@ -63,7 +63,7 @@ def setup(self, name: str, diode_config: DiodeConfig, policy: Policy): app_name=app_name, output_dir=diode_config.dry_run_output_dir, ) - else: + elif diode_config.client_id is not None and diode_config.client_secret is not None: client = DiodeClient( target=diode_config.target, app_name=app_name, @@ -71,6 +71,13 @@ def setup(self, name: str, diode_config: DiodeConfig, policy: Policy): client_id=diode_config.client_id, client_secret=diode_config.client_secret, ) + else: + logger.debug("Initializing Diode OTLP client") + client = DiodeOTLPClient( + target=diode_config.target, + app_name=app_name, + app_version=metadata.app_version, + ) self.metadata = metadata self.policy = policy @@ -120,13 +127,17 @@ def run( exec_start_time = time.perf_counter() try: entities = backend.run(self.name, policy) + metadata = { + "policy_name": self.name, + "worker_backend": self.metadata.name, + } for chunk_num, entity_chunk in enumerate(self._create_message_chunks(entities), 1): chunk_size_mb = self._estimate_message_size(entity_chunk) / (1024 * 1024) logger.debug( f"Ingesting chunk {chunk_num} with {len(entity_chunk)} entities (~{chunk_size_mb:.2f} MB)" ) - response = client.ingest(entities=entity_chunk) + response = client.ingest(entities=entity_chunk, metadata=metadata) if response.errors: raise RuntimeError(f"Chunk {chunk_num} ingestion failed: {response.errors}") logger.debug(f"Chunk {chunk_num} ingested successfully")