diff --git a/gcm/health_checks/checks/__init__.py b/gcm/health_checks/checks/__init__.py index f72ff8d..d60c8fa 100644 --- a/gcm/health_checks/checks/__init__.py +++ b/gcm/health_checks/checks/__init__.py @@ -6,6 +6,7 @@ from gcm.health_checks.checks.check_dcgmi import check_dcgmi from gcm.health_checks.checks.check_ethlink import check_ethlink from gcm.health_checks.checks.check_hca import check_hca +from gcm.health_checks.checks.check_ib_counters import check_ib_counters from gcm.health_checks.checks.check_ibstat import check_ib from gcm.health_checks.checks.check_ipmitool import check_ipmitool from gcm.health_checks.checks.check_nccl import check_nccl @@ -44,4 +45,5 @@ "check_blockdev", "check_ethlink", "check_sensors", + "check_ib_counters", ] diff --git a/gcm/health_checks/checks/check_ib_counters.py b/gcm/health_checks/checks/check_ib_counters.py new file mode 100644 index 0000000..f898d5c --- /dev/null +++ b/gcm/health_checks/checks/check_ib_counters.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +"""Check InfiniBand port error and throughput counters via sysfs. + +Reads counters from /sys/class/infiniband/{device}/ports/{port}/counters/ +and alerts when error counters exceed configurable thresholds. This is the +runtime complement to check_iblink: where check_iblink validates link +*presence*, check_ib_counters detects performance-degrading conditions that +silently hurt distributed training (NCCL AllReduce, FSDP, etc.). +""" + +import logging +import os +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import Optional, Protocol + +import click +from gcm.health_checks.check_utils.output_utils import CheckOutput, Metric +from gcm.health_checks.check_utils.runtime import HealthCheckRuntime +from gcm.health_checks.click import ( + common_arguments, + telemetry_argument, + timeout_argument, +) +from gcm.health_checks.types import CHECK_TYPE, CheckEnv, ExitCode, LOG_LEVEL +from gcm.monitoring.click import heterogeneous_cluster_v1_option +from gcm.monitoring.features.gen.generated_features_healthchecksfeatures import ( + FeatureValueHealthChecksFeatures, +) +from gcm.schemas.health_check.health_check_name import HealthCheckName +from typeguard import typechecked + +# Error counters that indicate fabric health problems. +# Any non-zero value is suspicious; rapid increase is critical. +ERROR_COUNTERS: list[str] = [ + "SymbolErrorCounter", + "LinkErrorRecoveryCounter", + "LinkDownedCounter", + "PortRcvErrors", + "PortRcvRemotePhysicalErrors", + "PortRcvSwitchRelayErrors", + "PortXmitDiscards", + "PortXmitConstraintErrors", + "PortRcvConstraintErrors", + "LocalLinkIntegrityErrors", + "ExcessiveBufferOverrunErrors", + "VL15Dropped", +] + +# Throughput counters (informational, included in metrics output). +THROUGHPUT_COUNTERS: list[str] = [ + "PortXmitData", + "PortRcvData", + "PortXmitPkts", + "PortRcvPkts", +] + +# Default threshold: total error count above which we alert. +DEFAULT_WARN_THRESHOLD: int = 0 +DEFAULT_CRIT_THRESHOLD: int = 100 + +SYSFS_IB_ROOT = "/sys/class/infiniband" + + +class IBCountersCheck(CheckEnv, Protocol): + """Protocol for IB counter reads — enables test injection.""" + + def discover_ports( + self, + logger: logging.Logger, + ) -> list[tuple[str, str]]: + """Return list of (device, port) tuples found on this node.""" + ... + + def read_counter( + self, + device: str, + port: str, + counter_name: str, + logger: logging.Logger, + ) -> Optional[int]: + """Read a single counter value from sysfs.""" + ... + + +@dataclass +class IBCountersCheckImpl: + """Production implementation — reads counters from sysfs.""" + + cluster: str + type: str + log_level: str + log_folder: str + + def discover_ports( + self, + logger: logging.Logger, + ) -> list[tuple[str, str]]: + """Discover all IB device/port pairs under /sys/class/infiniband.""" + ports: list[tuple[str, str]] = [] + try: + devices = os.listdir(SYSFS_IB_ROOT) + except OSError: + logger.warning("Cannot list %s", SYSFS_IB_ROOT) + return ports + + for dev in sorted(devices): + ports_dir = os.path.join(SYSFS_IB_ROOT, dev, "ports") + try: + port_nums = os.listdir(ports_dir) + except OSError: + logger.warning("Cannot list ports for %s", dev) + continue + for p in sorted(port_nums): + counters_dir = os.path.join(ports_dir, p, "counters") + if os.path.isdir(counters_dir): + ports.append((dev, p)) + return ports + + def read_counter( + self, + device: str, + port: str, + counter_name: str, + logger: logging.Logger, + ) -> Optional[int]: + """Read a single counter value from sysfs.""" + path = os.path.join( + SYSFS_IB_ROOT, device, "ports", port, "counters", counter_name + ) + try: + with open(path, "r") as f: + return int(f.read().strip()) + except (OSError, ValueError): + logger.debug( + "Failed to read counter %s for %s/%s", counter_name, device, port + ) + return None + + +@dataclass +class PortCounters: + """Parsed counters for a single IB port.""" + + device: str + port: str + errors: dict[str, int] = field(default_factory=dict) + throughput: dict[str, int] = field(default_factory=dict) + + +def collect_port_counters( + obj: IBCountersCheck, + logger: logging.Logger, +) -> list[PortCounters]: + """Read all counters for every discovered IB port.""" + results: list[PortCounters] = [] + for device, port in obj.discover_ports(logger): + pc = PortCounters(device=device, port=port) + for counter_name in ERROR_COUNTERS: + val = obj.read_counter(device, port, counter_name, logger) + if val is not None: + pc.errors[counter_name] = val + for counter_name in THROUGHPUT_COUNTERS: + val = obj.read_counter(device, port, counter_name, logger) + if val is not None: + pc.throughput[counter_name] = val + results.append(pc) + return results + + +def process_ib_counters( + port_counters: list[PortCounters], + warn_threshold: int, + crit_threshold: int, +) -> CheckOutput: + """Evaluate collected counters against thresholds. + + Returns a CheckOutput with per-port error details and Nagios metrics. + """ + check = CheckOutput("check_ib_counters") + + if not port_counters: + check.check_status = ExitCode.WARN + check.short_out = "No IB ports discovered" + return check + + total_errors = 0 + ports_with_errors = 0 + port_details: list[str] = [] + all_metrics: list[Metric] = [] + + for pc in port_counters: + port_label = f"{pc.device}/{pc.port}" + port_error_total = sum(pc.errors.values()) + total_errors += port_error_total + + if port_error_total > 0: + ports_with_errors += 1 + nonzero = [f"{name}={val}" for name, val in pc.errors.items() if val > 0] + port_details.append(f"{port_label}: {'; '.join(nonzero)}") + + # Emit metrics for each error counter. + for name, val in pc.errors.items(): + all_metrics.append( + Metric( + name=f"{port_label}.{name}", + value=val, + metric_warn=str(warn_threshold), + metric_crit=str(crit_threshold), + ) + ) + + # Emit throughput metrics (informational). + for name, val in pc.throughput.items(): + all_metrics.append(Metric(name=f"{port_label}.{name}", value=val)) + + # Determine overall status. + if total_errors > crit_threshold: + check.check_status = ExitCode.CRITICAL + elif total_errors > warn_threshold: + check.check_status = ExitCode.WARN + else: + check.check_status = ExitCode.OK + + check.short_out = ( + f"{len(port_counters)} ports checked, " + f"{ports_with_errors} with errors, " + f"total_errors={total_errors}" + ) + check.long_out = port_details + check.short_metrics = all_metrics + return check + + +@click.command() +@common_arguments +@timeout_argument +@telemetry_argument +@heterogeneous_cluster_v1_option +@click.option( + "--warn-threshold", + type=click.INT, + default=DEFAULT_WARN_THRESHOLD, + show_default=True, + help="Total error count above which the check returns WARNING.", +) +@click.option( + "--crit-threshold", + type=click.INT, + default=DEFAULT_CRIT_THRESHOLD, + show_default=True, + help="Total error count above which the check returns CRITICAL.", +) +@click.pass_obj +@typechecked +def check_ib_counters( + obj: Optional[IBCountersCheck], + cluster: str, + type: CHECK_TYPE, + log_level: LOG_LEVEL, + log_folder: str, + timeout: int, + sink: str, + sink_opts: Collection[str], + verbose_out: bool, + heterogeneous_cluster_v1: bool, + warn_threshold: int, + crit_threshold: int, +) -> None: + """Check IB port error counters against configurable thresholds. + + Reads /sys/class/infiniband/*/ports/*/counters/ for error counters + (SymbolErrorCounter, LinkDownedCounter, PortRcvErrors, …) and + throughput counters (PortXmitData, PortRcvData, …). Alerts when the + aggregate error count exceeds the configured thresholds. + """ + if not obj: + obj = IBCountersCheckImpl(cluster, type, log_level, log_folder) + + with HealthCheckRuntime( + cluster=cluster, + check_type=type, + log_level=log_level, + log_folder=log_folder, + sink=sink, + sink_opts=sink_opts, + verbose_out=verbose_out, + heterogeneous_cluster_v1=heterogeneous_cluster_v1, + health_check_name=HealthCheckName.CHECK_IB_COUNTERS, + killswitch_getter=lambda: FeatureValueHealthChecksFeatures().get_healthchecksfeatures_disable_check_ib_counters(), + ) as rt: + try: + port_counters = collect_port_counters(obj, rt.logger) + except Exception: + rt.logger.exception("Failed to collect IB counters") + port_counters = [] + + output = process_ib_counters(port_counters, warn_threshold, crit_threshold) + rt.finish(output.check_status, str(output)) diff --git a/gcm/health_checks/cli/health_checks.py b/gcm/health_checks/cli/health_checks.py index 44b8660..ac2bffb 100644 --- a/gcm/health_checks/cli/health_checks.py +++ b/gcm/health_checks/cli/health_checks.py @@ -53,6 +53,7 @@ def health_checks(detach: bool) -> None: checks.check_blockdev, checks.check_ethlink, checks.check_sensors, + checks.check_ib_counters, ] for check in list_of_checks: diff --git a/gcm/monitoring/features/feature_definitions/health_checks_features.py b/gcm/monitoring/features/feature_definitions/health_checks_features.py index 395bdd1..4460167 100644 --- a/gcm/monitoring/features/feature_definitions/health_checks_features.py +++ b/gcm/monitoring/features/feature_definitions/health_checks_features.py @@ -63,3 +63,4 @@ class HealthChecksFeatures: disable_check_clocksource: bool disable_airstore_credential_count: bool disable_check_sensors: bool + disable_check_ib_counters: bool diff --git a/gcm/monitoring/features/gen/generated_features_healthchecksfeatures.py b/gcm/monitoring/features/gen/generated_features_healthchecksfeatures.py index 1372d5f..d6def87 100644 --- a/gcm/monitoring/features/gen/generated_features_healthchecksfeatures.py +++ b/gcm/monitoring/features/gen/generated_features_healthchecksfeatures.py @@ -898,3 +898,17 @@ def get_healthchecksfeatures_disable_check_sensors(self) -> bool: f"Expected bool value for HealthChecksFeatures.disable_check_sensors, got {type(value).__name__} instead." ) return value + + def get_healthchecksfeatures_disable_check_ib_counters(self) -> bool: + try: + features = self.load_config() + except Exception: + return False + value = features.get("HealthChecksFeatures", {}).get( + "disable_check_ib_counters", False + ) + if not isinstance(value, bool): + raise TypeError( + f"Expected bool value for HealthChecksFeatures.disable_check_ib_counters, got {type(value).__name__} instead." + ) + return value diff --git a/gcm/schemas/health_check/health_check_name.py b/gcm/schemas/health_check/health_check_name.py index a27775f..9dc7045 100644 --- a/gcm/schemas/health_check/health_check_name.py +++ b/gcm/schemas/health_check/health_check_name.py @@ -66,3 +66,4 @@ class HealthCheckName(Enum): CHECK_ETHLINK = "check ethlink" CHECK_CLOCKSOURCE = "check clocksource" CHECK_SENSORS = "check sensors" + CHECK_IB_COUNTERS = "check ib counters" diff --git a/gcm/tests/health_checks_tests/test_check_ib_counters.py b/gcm/tests/health_checks_tests/test_check_ib_counters.py new file mode 100644 index 0000000..402a2f1 --- /dev/null +++ b/gcm/tests/health_checks_tests/test_check_ib_counters.py @@ -0,0 +1,368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +"""Test the check_ib_counters health-check.""" + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import pytest +from click.testing import CliRunner +from gcm.health_checks.checks.check_ib_counters import ( + check_ib_counters, + PortCounters, + process_ib_counters, +) +from gcm.health_checks.types import ExitCode + + +# --------------------------------------------------------------------------- +# Fake implementation — injected via click's obj parameter +# --------------------------------------------------------------------------- +@dataclass +class FakeIBCountersCheckImpl: + """Return pre-configured IB counter data instead of reading sysfs.""" + + ports: list[tuple[str, str]] = field(default_factory=list) + counters: dict[str, dict[str, int]] = field(default_factory=dict) + + cluster = "test cluster" + type = "prolog" + log_level = "INFO" + log_folder = "/tmp" + + def discover_ports( + self, + _logger: logging.Logger, + ) -> list[tuple[str, str]]: + """Return pre-configured port list.""" + return self.ports + + def read_counter( + self, + device: str, + port: str, + counter_name: str, + _logger: logging.Logger, + ) -> Optional[int]: + """Return pre-configured counter value.""" + key = f"{device}/{port}" + return self.counters.get(key, {}).get(counter_name) + + +# --------------------------------------------------------------------------- +# Test data +# --------------------------------------------------------------------------- +NO_PORTS = FakeIBCountersCheckImpl(ports=[], counters={}) + +CLEAN_SINGLE_PORT = FakeIBCountersCheckImpl( + ports=[("mlx5_0", "1")], + counters={ + "mlx5_0/1": { + "SymbolErrorCounter": 0, + "LinkErrorRecoveryCounter": 0, + "LinkDownedCounter": 0, + "PortRcvErrors": 0, + "PortRcvRemotePhysicalErrors": 0, + "PortRcvSwitchRelayErrors": 0, + "PortXmitDiscards": 0, + "PortXmitConstraintErrors": 0, + "PortRcvConstraintErrors": 0, + "LocalLinkIntegrityErrors": 0, + "ExcessiveBufferOverrunErrors": 0, + "VL15Dropped": 0, + "PortXmitData": 123456789, + "PortRcvData": 987654321, + "PortXmitPkts": 1000000, + "PortRcvPkts": 2000000, + }, + }, +) + +WARN_SINGLE_PORT = FakeIBCountersCheckImpl( + ports=[("mlx5_0", "1")], + counters={ + "mlx5_0/1": { + "SymbolErrorCounter": 5, + "LinkErrorRecoveryCounter": 0, + "LinkDownedCounter": 0, + "PortRcvErrors": 2, + "PortRcvRemotePhysicalErrors": 0, + "PortRcvSwitchRelayErrors": 0, + "PortXmitDiscards": 0, + "PortXmitConstraintErrors": 0, + "PortRcvConstraintErrors": 0, + "LocalLinkIntegrityErrors": 0, + "ExcessiveBufferOverrunErrors": 0, + "VL15Dropped": 0, + "PortXmitData": 100, + "PortRcvData": 200, + "PortXmitPkts": 50, + "PortRcvPkts": 60, + }, + }, +) + +CRITICAL_MULTI_PORT = FakeIBCountersCheckImpl( + ports=[("mlx5_0", "1"), ("mlx5_1", "1")], + counters={ + "mlx5_0/1": { + "SymbolErrorCounter": 50, + "LinkErrorRecoveryCounter": 10, + "LinkDownedCounter": 5, + "PortRcvErrors": 30, + "PortRcvRemotePhysicalErrors": 0, + "PortRcvSwitchRelayErrors": 0, + "PortXmitDiscards": 10, + "PortXmitConstraintErrors": 0, + "PortRcvConstraintErrors": 0, + "LocalLinkIntegrityErrors": 0, + "ExcessiveBufferOverrunErrors": 0, + "VL15Dropped": 0, + "PortXmitData": 100, + "PortRcvData": 200, + "PortXmitPkts": 50, + "PortRcvPkts": 60, + }, + "mlx5_1/1": { + "SymbolErrorCounter": 0, + "LinkErrorRecoveryCounter": 0, + "LinkDownedCounter": 0, + "PortRcvErrors": 0, + "PortRcvRemotePhysicalErrors": 0, + "PortRcvSwitchRelayErrors": 0, + "PortXmitDiscards": 0, + "PortXmitConstraintErrors": 0, + "PortRcvConstraintErrors": 0, + "LocalLinkIntegrityErrors": 0, + "ExcessiveBufferOverrunErrors": 0, + "VL15Dropped": 0, + "PortXmitData": 300, + "PortRcvData": 400, + "PortXmitPkts": 70, + "PortRcvPkts": 80, + }, + }, +) + +CLEAN_MULTI_PORT = FakeIBCountersCheckImpl( + ports=[("mlx5_0", "1"), ("mlx5_1", "1"), ("mlx5_2", "1"), ("mlx5_3", "1")], + counters={ + f"mlx5_{i}/1": { + "SymbolErrorCounter": 0, + "LinkErrorRecoveryCounter": 0, + "LinkDownedCounter": 0, + "PortRcvErrors": 0, + "PortRcvRemotePhysicalErrors": 0, + "PortRcvSwitchRelayErrors": 0, + "PortXmitDiscards": 0, + "PortXmitConstraintErrors": 0, + "PortRcvConstraintErrors": 0, + "LocalLinkIntegrityErrors": 0, + "ExcessiveBufferOverrunErrors": 0, + "VL15Dropped": 0, + "PortXmitData": 1000 * (i + 1), + "PortRcvData": 2000 * (i + 1), + "PortXmitPkts": 100 * (i + 1), + "PortRcvPkts": 200 * (i + 1), + } + for i in range(4) + }, +) + + +# --------------------------------------------------------------------------- +# Fixture +# --------------------------------------------------------------------------- +@pytest.fixture +def ib_counters_tester( + request: pytest.FixtureRequest, +) -> FakeIBCountersCheckImpl: + """Create FakeIBCountersCheckImpl object.""" + return request.param + + +# --------------------------------------------------------------------------- +# Unit tests for the pure process_ib_counters function +# --------------------------------------------------------------------------- +class TestProcessIBCounters: + """Test the pure processing logic directly without Click.""" + + def test_no_ports_returns_warn(self) -> None: + result = process_ib_counters([], warn_threshold=0, crit_threshold=100) + assert result.check_status == ExitCode.WARN + assert "No IB ports discovered" in result.short_out + + def test_clean_port_returns_ok(self) -> None: + pc = PortCounters( + device="mlx5_0", + port="1", + errors={"SymbolErrorCounter": 0, "PortRcvErrors": 0}, + throughput={"PortXmitData": 12345}, + ) + result = process_ib_counters([pc], warn_threshold=0, crit_threshold=100) + assert result.check_status == ExitCode.OK + assert "0 with errors" in result.short_out + + def test_errors_above_warn_threshold(self) -> None: + pc = PortCounters( + device="mlx5_0", + port="1", + errors={"SymbolErrorCounter": 5, "PortRcvErrors": 3}, + throughput={}, + ) + result = process_ib_counters([pc], warn_threshold=0, crit_threshold=100) + assert result.check_status == ExitCode.WARN + assert "1 with errors" in result.short_out + assert "total_errors=8" in result.short_out + + def test_errors_above_crit_threshold(self) -> None: + pc = PortCounters( + device="mlx5_0", + port="1", + errors={"SymbolErrorCounter": 80, "PortRcvErrors": 30}, + throughput={}, + ) + result = process_ib_counters([pc], warn_threshold=0, crit_threshold=100) + assert result.check_status == ExitCode.CRITICAL + assert "total_errors=110" in result.short_out + + def test_long_out_lists_nonzero_counters(self) -> None: + pc = PortCounters( + device="mlx5_0", + port="1", + errors={ + "SymbolErrorCounter": 5, + "PortRcvErrors": 0, + "LinkDownedCounter": 2, + }, + throughput={}, + ) + result = process_ib_counters([pc], warn_threshold=0, crit_threshold=100) + assert len(result.long_out) == 1 + assert "SymbolErrorCounter=5" in result.long_out[0] + assert "LinkDownedCounter=2" in result.long_out[0] + assert "PortRcvErrors" not in result.long_out[0] + + def test_multi_port_mixed_errors(self) -> None: + clean_port = PortCounters( + device="mlx5_0", + port="1", + errors={"SymbolErrorCounter": 0}, + throughput={}, + ) + bad_port = PortCounters( + device="mlx5_1", + port="1", + errors={"SymbolErrorCounter": 10}, + throughput={}, + ) + result = process_ib_counters( + [clean_port, bad_port], + warn_threshold=0, + crit_threshold=100, + ) + assert result.check_status == ExitCode.WARN + assert "2 ports checked" in result.short_out + assert "1 with errors" in result.short_out + assert len(result.long_out) == 1 + assert "mlx5_1/1" in result.long_out[0] + + def test_metrics_emitted_for_all_counters(self) -> None: + pc = PortCounters( + device="mlx5_0", + port="1", + errors={"SymbolErrorCounter": 0, "PortRcvErrors": 0}, + throughput={"PortXmitData": 999}, + ) + result = process_ib_counters([pc], warn_threshold=0, crit_threshold=100) + metric_names = [m.name for m in result.short_metrics] + assert "mlx5_0/1.SymbolErrorCounter" in metric_names + assert "mlx5_0/1.PortRcvErrors" in metric_names + assert "mlx5_0/1.PortXmitData" in metric_names + + +# --------------------------------------------------------------------------- +# Integration tests via Click runner +# --------------------------------------------------------------------------- +@pytest.mark.parametrize( + ("ib_counters_tester", "expected"), + [ + ( + NO_PORTS, + (ExitCode.WARN, "No IB ports discovered"), + ), + ( + CLEAN_SINGLE_PORT, + (ExitCode.OK, "1 ports checked, 0 with errors"), + ), + ( + WARN_SINGLE_PORT, + (ExitCode.WARN, "1 ports checked, 1 with errors"), + ), + ( + CRITICAL_MULTI_PORT, + (ExitCode.CRITICAL, "2 ports checked, 1 with errors"), + ), + ( + CLEAN_MULTI_PORT, + (ExitCode.OK, "4 ports checked, 0 with errors"), + ), + ], + indirect=["ib_counters_tester"], +) +def test_check_ib_counters( + caplog: pytest.LogCaptureFixture, + tmp_path: Path, + ib_counters_tester: FakeIBCountersCheckImpl, + expected: tuple[ExitCode, str], +) -> None: + """Invoke check_ib_counters via Click and verify exit code and output.""" + runner = CliRunner(mix_stderr=False) + caplog.at_level(logging.INFO) + + result = runner.invoke( + check_ib_counters, + f"fair_cluster prolog --log-folder={tmp_path} --sink=do_nothing", + obj=ib_counters_tester, + ) + + assert result.exit_code == expected[0].value + assert expected[1] in caplog.text + + +@pytest.mark.parametrize( + ("ib_counters_tester", "threshold_args", "expected_exit"), + [ + ( + WARN_SINGLE_PORT, + "--warn-threshold=10 --crit-threshold=100", + ExitCode.OK, + ), + ( + WARN_SINGLE_PORT, + "--warn-threshold=0 --crit-threshold=5", + ExitCode.CRITICAL, + ), + ], + indirect=["ib_counters_tester"], +) +def test_check_ib_counters_custom_thresholds( + caplog: pytest.LogCaptureFixture, + tmp_path: Path, + ib_counters_tester: FakeIBCountersCheckImpl, + threshold_args: str, + expected_exit: ExitCode, +) -> None: + """Verify that --warn-threshold and --crit-threshold are respected.""" + runner = CliRunner(mix_stderr=False) + caplog.at_level(logging.INFO) + + result = runner.invoke( + check_ib_counters, + f"fair_cluster prolog --log-folder={tmp_path} --sink=do_nothing {threshold_args}", + obj=ib_counters_tester, + ) + + assert result.exit_code == expected_exit.value