diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 3f6d44cab..493554822 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -4,6 +4,7 @@ import time import eth_utils import logging +from contextlib import asynccontextmanager from functools import partial, wraps from typing import Any from backend.protocol_rpc.exceptions import ( @@ -63,11 +64,13 @@ from backend.node.base import Manager as GenVMManager import asyncio -# Limit concurrent GenVM executions on the jsonrpc path to prevent uvloop fd conflicts. -# Workers use asyncio.Semaphore(8) in consensus/base.py; gen_call had none, allowing -# unbounded concurrent GenVM socket operations that cause fd registry collisions. +# Limit concurrent GenVM executions on the jsonrpc path to prevent uvloop fd +# conflicts and DB pool exhaustion while calls hold request-scoped sessions. +# Workers use asyncio.Semaphore(8) in consensus/base.py; keep the RPC path +# bounded too. _GENVM_CONCURRENCY = int(os.environ.get("GENVM_MAX_CONCURRENT", "8")) _genvm_semaphore = asyncio.Semaphore(_GENVM_CONCURRENCY) +_genvm_admission_semaphore = asyncio.Semaphore(_GENVM_CONCURRENCY) # --------------------------------------------------------------------------- # Per-address rate limiting for gen_call / sim_call @@ -97,13 +100,36 @@ def _check_rate_limit(address: str) -> None: ) raise JSONRPCError( code=-32005, - message=f"Rate limit exceeded: max {_RATE_LIMIT_MAX} gen_call requests per {_RATE_LIMIT_WINDOW}s per contract address", + message=f"Rate limit exceeded: max {_RATE_LIMIT_MAX} gen_call/sim_call requests per {_RATE_LIMIT_WINDOW}s per contract address", data={"address": address, "retry_after_seconds": _RATE_LIMIT_WINDOW}, ) timestamps.append(now) _address_request_log[address] = timestamps +@asynccontextmanager +async def _admit_genvm_call(method: str, to_address: str | None): + """Reject GenVM-backed RPC calls instead of queueing unlimited work.""" + if _genvm_admission_semaphore.locked(): + _rate_limit_logger.warning( + "GenVM at capacity (%s concurrent) - rejecting %s to %s", + _GENVM_CONCURRENCY, + method, + to_address, + ) + raise JSONRPCError( + code=-32006, + message=f"Server busy: all {_GENVM_CONCURRENCY} execution slots occupied, retry later", + data={"retry_after_seconds": 2}, + ) + + await _genvm_admission_semaphore.acquire() + try: + yield + finally: + _genvm_admission_semaphore.release() + + # --------------------------------------------------------------------------- # Admission control on PENDING queue depth (eth_sendRawTransaction path). # @@ -1153,15 +1179,17 @@ async def gen_call( genvm_manager: GenVMManager, params: dict, ) -> str: - receipt = await _execute_call_with_snapshot( - session, - accounts_manager, - msg_handler, - transactions_parser, - validators_manager, - genvm_manager, - params, - ) + to_address = params.get("to") if isinstance(params, dict) else None + async with _admit_genvm_call("gen_call", to_address): + receipt = await _execute_call_with_snapshot( + session, + accounts_manager, + msg_handler, + transactions_parser, + validators_manager, + genvm_manager, + params, + ) return eth_utils.hexadecimal.encode_hex(receipt.result[1:])[2:] @@ -1190,15 +1218,17 @@ async def sim_call( genvm_manager: GenVMManager, params: dict, ) -> dict: - receipt = await _execute_call_with_snapshot( - session, - accounts_manager, - msg_handler, - transactions_parser, - validators_manager, - genvm_manager, - params, - ) + to_address = params.get("to") if isinstance(params, dict) else None + async with _admit_genvm_call("sim_call", to_address): + receipt = await _execute_call_with_snapshot( + session, + accounts_manager, + msg_handler, + transactions_parser, + validators_manager, + genvm_manager, + params, + ) return receipt.to_dict() @@ -1469,45 +1499,45 @@ async def eth_call( if not accounts_manager.is_valid_address(from_address): raise InvalidAddressError(from_address) - decoded_data = transactions_parser.decode_method_call_data(data) + async with _admit_genvm_call("eth_call", to_address): + decoded_data = transactions_parser.decode_method_call_data(data) - async with validators_manager.snapshot() as snapshot: - print(snapshot.nodes) - if len(snapshot.nodes) == 0: - raise JSONRPCError( - code=-32000, - message="No validators available to execute eth_call", - data={"reason": "no_validators"}, - ) - as_validator = snapshot.nodes[0].validator - try: - target_contract_snapshot = ContractSnapshot(to_address, session) - except ContractNotFoundError: - raise NotFoundError( - message=f"Contract {to_address} not found", - data={"contract_address": to_address}, + async with validators_manager.snapshot() as snapshot: + if len(snapshot.nodes) == 0: + raise JSONRPCError( + code=-32000, + message="No validators available to execute eth_call", + data={"reason": "no_validators"}, + ) + as_validator = snapshot.nodes[0].validator + try: + target_contract_snapshot = ContractSnapshot(to_address, session) + except ContractNotFoundError: + raise NotFoundError( + message=f"Contract {to_address} not found", + data={"contract_address": to_address}, + ) + node = Node( # Mock node just to get the data from the GenVM + contract_snapshot=target_contract_snapshot, + contract_snapshot_factory=partial(ContractSnapshot, session=session), + validator_mode=ExecutionMode.LEADER, + validator=as_validator, + leader_receipt=None, + msg_handler=msg_handler.with_client_session(get_client_session_id()), + validators_snapshot=snapshot, + manager=genvm_manager, ) - node = Node( # Mock node just to get the data from the GenVM - contract_snapshot=target_contract_snapshot, - contract_snapshot_factory=partial(ContractSnapshot, session=session), - validator_mode=ExecutionMode.LEADER, - validator=as_validator, - leader_receipt=None, - msg_handler=msg_handler.with_client_session(get_client_session_id()), - validators_snapshot=snapshot, - manager=genvm_manager, - ) - try: - receipt = await node.get_contract_data( - from_address=as_validator.address, - calldata=decoded_data.calldata, - ) - except ContractNotFoundError as e: - raise NotFoundError( - message=f"Contract {e.address} not found", - data={"contract_address": e.address}, - ) from e + try: + receipt = await node.get_contract_data( + from_address=as_validator.address, + calldata=decoded_data.calldata, + ) + except ContractNotFoundError as e: + raise NotFoundError( + message=f"Contract {e.address} not found", + data={"contract_address": e.address}, + ) from e if receipt.execution_result != ExecutionResultStatus.SUCCESS: raise JSONRPCError( diff --git a/backend/protocol_rpc/health.py b/backend/protocol_rpc/health.py index e184190d4..e28370b98 100644 --- a/backend/protocol_rpc/health.py +++ b/backend/protocol_rpc/health.py @@ -214,6 +214,7 @@ def _evaluate_permit_readiness( # Send system health metrics every 6 health checks (6 × 10s = 60s = 1 minute) METRICS_SEND_INTERVAL = 6 +_no_progress_scan_suppressed_until: float = 0.0 def get_health_check_interval() -> float: @@ -221,6 +222,11 @@ def get_health_check_interval() -> float: return float(os.getenv("HEALTH_CHECK_INTERVAL_SECONDS", "10")) +def get_no_progress_scan_error_cooldown_seconds() -> float: + """Cooldown after the expensive no-progress scan times out.""" + return float(os.getenv("HEALTH_NO_PROGRESS_SCAN_ERROR_COOLDOWN_SECONDS", "300")) + + def _update_genvm_health_cache( services: Dict[str, Any], genvm_ok: bool, @@ -328,6 +334,9 @@ async def _run_health_checks() -> None: "no_progress_check_error": consensus_health.get( "no_progress_check_error", False ), + "no_progress_scan_suppressed": consensus_health.get( + "no_progress_scan_suppressed", False + ), "active_workers": consensus_health.get("active_workers", 0), "status": consensus_status, } @@ -614,6 +623,9 @@ async def _check_consensus_health() -> Dict[str, Any]: NO_PROGRESS_QUERY_TIMEOUT_MS = int( os.environ.get("HEALTH_NO_PROGRESS_QUERY_TIMEOUT_MS", "5000") ) + NO_PROGRESS_SCAN_ERROR_COOLDOWN_SECONDS = ( + get_no_progress_scan_error_cooldown_seconds() + ) RECOVERY_STORM_MIN_RECOVERIES = int( os.environ.get("HEALTH_RECOVERY_STORM_MIN_RECOVERIES", "2") ) @@ -651,6 +663,8 @@ async def _check_consensus_health() -> Dict[str, Any]: db_manager = get_database_manager() def _query_consensus(): + global _no_progress_scan_suppressed_until + from sqlalchemy import text with db_manager.engine.connect() as conn: @@ -851,6 +865,7 @@ def _query_consensus(): no_progress_window_seconds = NO_PROGRESS_WINDOW_MINUTES * 60 no_progress_check_error = False + no_progress_scan_suppressed = False # No-progress detector: first do a cheap backlog gate. The # progress scan has to inspect JSON history and can be expensive @@ -887,7 +902,13 @@ def _query_consensus(): seconds_since_consensus_progress = None last_progress_epoch = 0 - if should_scan_progress: + if ( + should_scan_progress + and time.time() < _no_progress_scan_suppressed_until + ): + no_progress_check_error = True + no_progress_scan_suppressed = True + elif should_scan_progress: try: conn.execute( text( @@ -943,12 +964,18 @@ def _query_consensus(): if last_progress_epoch else None ) + _no_progress_scan_suppressed_until = 0.0 except Exception as exc: no_progress_check_error = True + _no_progress_scan_suppressed_until = ( + time.time() + NO_PROGRESS_SCAN_ERROR_COOLDOWN_SECONDS + ) logger.warning( "No-progress health query skipped after timeout/error: %s", exc, ) + else: + _no_progress_scan_suppressed_until = 0.0 # The progress scan is an alert-quality check, not a liveness # requirement. If it times out on a large table, surface that @@ -1005,6 +1032,7 @@ def _query_consensus(): seconds_since_consensus_progress ), "no_progress_check_error": no_progress_check_error, + "no_progress_scan_suppressed": no_progress_scan_suppressed, "no_progress_window_minutes": NO_PROGRESS_WINDOW_MINUTES, "active_workers": active_workers_count, } diff --git a/tests/db-sqlalchemy/test_health_orphan_detection.py b/tests/db-sqlalchemy/test_health_orphan_detection.py index c4c6aea5e..044d202f3 100644 --- a/tests/db-sqlalchemy/test_health_orphan_detection.py +++ b/tests/db-sqlalchemy/test_health_orphan_detection.py @@ -50,14 +50,17 @@ def _wire_health_module_to_test_engine(engine: Engine): prev_mgr = session_factory._db_manager prev_router = health_module._rpc_router_ref + prev_scan_suppressed_until = health_module._no_progress_scan_suppressed_until session_factory._db_manager = _StubManager(engine=engine) health_module._rpc_router_ref = object() # truthy + health_module._no_progress_scan_suppressed_until = 0.0 yield session_factory._db_manager = prev_mgr health_module._rpc_router_ref = prev_router + health_module._no_progress_scan_suppressed_until = prev_scan_suppressed_until def _insert_tx( @@ -780,6 +783,60 @@ def fail_on_progress_scan( assert result["status"] == "healthy" +@pytest.mark.asyncio +async def test_no_progress_scan_timeout_enters_cooldown( + engine: Engine, monkeypatch: pytest.MonkeyPatch +): + """After a timeout, skip the optional JSON scan briefly instead of + re-running the same expensive query every health tick.""" + monkeypatch.setenv("HEALTH_NO_PROGRESS_WINDOW_MINUTES", "10") + monkeypatch.setenv("HEALTH_NO_PROGRESS_MIN_BACKLOG", "3") + monkeypatch.setenv("HEALTH_NO_PROGRESS_SCAN_ERROR_COOLDOWN_SECONDS", "300") + + Session_ = sessionmaker(bind=engine, expire_on_commit=False) + now = datetime.now(timezone.utc) + + with Session_() as s: + for i in range(3): + _insert_tx( + s, + tx_hash=f"0xe0{i:062x}", + to_address="0x" + "e0" * 20, + status="PENDING", + nonce=i, + created_at=now - timedelta(minutes=20 + i), + ) + s.commit() + + progress_scan_count = 0 + + def fail_first_progress_scan( + conn, cursor, statement, parameters, context, executemany + ): + nonlocal progress_scan_count + if "current_monitoring" not in statement: + return + progress_scan_count += 1 + if progress_scan_count == 1: + raise RuntimeError("simulated progress scan timeout") + raise AssertionError("progress scan should be suppressed during cooldown") + + event.listen(engine, "before_cursor_execute", fail_first_progress_scan) + try: + from backend.protocol_rpc import health as health_module + + first = await health_module._check_consensus_health() + second = await health_module._check_consensus_health() + finally: + event.remove(engine, "before_cursor_execute", fail_first_progress_scan) + + assert first["no_progress_check_error"] is True + assert first["no_progress_scan_suppressed"] is False + assert second["no_progress_check_error"] is True + assert second["no_progress_scan_suppressed"] is True + assert progress_scan_count == 1 + + @pytest.mark.asyncio async def test_recent_consensus_progress_suppresses_no_progress_alert( engine: Engine, monkeypatch: pytest.MonkeyPatch diff --git a/tests/unit/test_rpc_genvm_admission.py b/tests/unit/test_rpc_genvm_admission.py new file mode 100644 index 000000000..182c142e9 --- /dev/null +++ b/tests/unit/test_rpc_genvm_admission.py @@ -0,0 +1,179 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.protocol_rpc import endpoints +from backend.protocol_rpc.exceptions import JSONRPCError +from backend.node.types import ExecutionResultStatus + + +class _AsyncSnapshot: + def __init__(self, snapshot): + self.snapshot = snapshot + + async def __aenter__(self): + return self.snapshot + + async def __aexit__(self, exc_type, exc, traceback): + return False + + +@pytest.mark.asyncio +async def test_genvm_admission_rejects_when_slots_full(monkeypatch): + monkeypatch.setattr(endpoints, "_GENVM_CONCURRENCY", 1) + monkeypatch.setattr(endpoints, "_genvm_admission_semaphore", asyncio.Semaphore(0)) + + with pytest.raises(JSONRPCError) as exc_info: + async with endpoints._admit_genvm_call("eth_call", "0xabc"): + pass + + assert exc_info.value.code == -32006 + assert exc_info.value.data["retry_after_seconds"] == 2 + + +@pytest.mark.asyncio +async def test_genvm_admission_releases_slot_after_error(monkeypatch): + semaphore = asyncio.Semaphore(1) + monkeypatch.setattr(endpoints, "_genvm_admission_semaphore", semaphore) + + with pytest.raises(RuntimeError): + async with endpoints._admit_genvm_call("eth_call", "0xabc"): + raise RuntimeError("boom") + + assert semaphore._value == 1 + + +@pytest.mark.asyncio +async def test_eth_call_rejects_before_db_snapshot_when_genvm_full(monkeypatch): + monkeypatch.setattr(endpoints, "_GENVM_CONCURRENCY", 1) + monkeypatch.setattr(endpoints, "_genvm_admission_semaphore", asyncio.Semaphore(0)) + monkeypatch.setattr( + endpoints, "handle_consensus_data_call", lambda *args, **kwargs: None + ) + + accounts_manager = MagicMock() + accounts_manager.is_valid_address.return_value = True + + params = { + "to": "0x" + "ab" * 20, + "from": "0x" + "cd" * 20, + "data": "0x1234", + } + + with patch("backend.protocol_rpc.endpoints.ContractSnapshot") as snapshot_cls: + with pytest.raises(JSONRPCError) as exc_info: + await endpoints.eth_call( + session=MagicMock(), + accounts_manager=accounts_manager, + msg_handler=MagicMock(), + transactions_parser=MagicMock(), + validators_manager=MagicMock(), + genvm_manager=MagicMock(), + transactions_processor=MagicMock(), + params=params, + ) + + assert exc_info.value.code == -32006 + snapshot_cls.assert_not_called() + + +@pytest.mark.asyncio +async def test_gen_call_rejects_before_validator_snapshot_when_genvm_full(monkeypatch): + monkeypatch.setattr(endpoints, "_GENVM_CONCURRENCY", 1) + monkeypatch.setattr(endpoints, "_genvm_admission_semaphore", asyncio.Semaphore(0)) + + validators_manager = MagicMock() + + with pytest.raises(JSONRPCError) as exc_info: + await endpoints.gen_call( + session=MagicMock(), + accounts_manager=MagicMock(), + msg_handler=MagicMock(), + transactions_parser=MagicMock(), + validators_manager=validators_manager, + genvm_manager=MagicMock(), + params={"to": "0x" + "ab" * 20}, + ) + + assert exc_info.value.code == -32006 + validators_manager.snapshot.assert_not_called() + + +@pytest.mark.asyncio +async def test_sim_call_rejects_before_validator_snapshot_when_genvm_full(monkeypatch): + monkeypatch.setattr(endpoints, "_GENVM_CONCURRENCY", 1) + monkeypatch.setattr(endpoints, "_genvm_admission_semaphore", asyncio.Semaphore(0)) + + validators_manager = MagicMock() + + with pytest.raises(JSONRPCError) as exc_info: + await endpoints.sim_call( + session=MagicMock(), + accounts_manager=MagicMock(), + msg_handler=MagicMock(), + transactions_parser=MagicMock(), + validators_manager=validators_manager, + genvm_manager=MagicMock(), + params={"to": "0x" + "ab" * 20}, + ) + + assert exc_info.value.code == -32006 + validators_manager.snapshot.assert_not_called() + + +@pytest.mark.asyncio +async def test_eth_call_releases_admission_slot_after_success(monkeypatch): + semaphore = asyncio.Semaphore(1) + monkeypatch.setattr(endpoints, "_genvm_admission_semaphore", semaphore) + monkeypatch.setattr( + endpoints, "handle_consensus_data_call", lambda *args, **kwargs: None + ) + + accounts_manager = MagicMock() + accounts_manager.is_valid_address.return_value = True + + decoded_data = MagicMock(calldata=b"\x12\x34") + transactions_parser = MagicMock() + transactions_parser.decode_method_call_data.return_value = decoded_data + + validator = MagicMock(address="0xvalidator") + snapshot = MagicMock(nodes=[MagicMock(validator=validator)]) + validators_manager = MagicMock() + validators_manager.snapshot.return_value = _AsyncSnapshot(snapshot) + + receipt = MagicMock( + execution_result=ExecutionResultStatus.SUCCESS, + result=b"\x00\x12\x34", + ) + node = MagicMock() + node.get_contract_data = AsyncMock(return_value=receipt) + + msg_handler = MagicMock() + msg_handler.with_client_session.return_value = MagicMock() + + params = { + "to": "0x" + "ab" * 20, + "from": "0x" + "cd" * 20, + "data": "0x1234", + } + + with patch("backend.protocol_rpc.endpoints.ContractSnapshot"): + with patch("backend.protocol_rpc.endpoints.Node", return_value=node): + result = await endpoints.eth_call( + session=MagicMock(), + accounts_manager=accounts_manager, + msg_handler=msg_handler, + transactions_parser=transactions_parser, + validators_manager=validators_manager, + genvm_manager=MagicMock(), + transactions_processor=MagicMock(), + params=params, + ) + + assert result == "0x1234" + assert semaphore._value == 1 + node.get_contract_data.assert_awaited_once_with( + from_address=validator.address, + calldata=decoded_data.calldata, + )