Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions hindsight-api-slim/hindsight_api/worker/poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .stage import StageHolder, bind_holder

if TYPE_CHECKING:
from hindsight_api.engine.db.base import DatabaseBackend
from hindsight_api.engine.db.base import DatabaseBackend, DatabaseConnection
from hindsight_api.extensions.tenant import TenantExtension

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -187,13 +187,18 @@ def __init__(
# schema we serviced so a busy tenant can't monopolize the poll order.
self._next_schema_idx: int = 0

async def _get_schemas(self) -> list[str | None]:
"""Get list of schemas to poll. Returns [None] for default schema (no prefix)."""
@staticmethod
def _normalize_poll_schema(schema: str | None) -> str | None:
"""Use None internally for the default schema because SQL helpers omit that prefix."""
from ..config import DEFAULT_DATABASE_SCHEMA

return None if schema == DEFAULT_DATABASE_SCHEMA else schema

async def _get_schemas(self) -> list[str | None]:
"""Get list of schemas to poll. Returns [None] for default schema (no prefix)."""
tenants = await self._tenant_extension.list_tenants()
# Convert default schema to None for SQL compatibility (no prefix), keep others as-is
return [t.schema if t.schema != DEFAULT_DATABASE_SCHEMA else None for t in tenants]
return [self._normalize_poll_schema(t.schema) for t in tenants]

async def _scan_active_schemas(self, schemas: list[str | None]) -> set[str | None]:
"""Find which schemas have pending work.
Expand All @@ -213,22 +218,57 @@ async def _scan_active_schemas(self, schemas: list[str | None]) -> set[str | Non
async with self._backend.acquire() as conn:
if await self._optional_routines.is_installed(conn, "schemas_with_pending_work"):
rows = await conn.fetch("SELECT * FROM public.schemas_with_pending_work()")
return {r[0] for r in rows}
routine_active = {self._normalize_poll_schema(r[0]) for r in rows}
known_schemas = set(schemas)
active = routine_active & known_schemas
unknown = routine_active - known_schemas
if unknown:
logger.warning(
"Optional PG routine public.schemas_with_pending_work() returned schema(s) "
"not present in tenant discovery: %s",
sorted(str(s) for s in unknown),
)

# Fallback: per-schema EXISTS checks from Python
active: set[str | None] = set()
for schema in schemas:
table = fq_table("async_operations", schema)
try:
has_work = await conn.fetchval(
f"SELECT EXISTS(SELECT 1 FROM {table} "
f"WHERE status = 'pending' AND task_payload IS NOT NULL LIMIT 1)"
# The optional routine returns PostgreSQL schema names, but the poller uses
# None for the default schema. Older operator-supplied implementations also
# commonly scan tenant_% only; when the default schema is in scope but absent
# from the routine result, verify via the fully-correct per-schema fallback so
# public single-tenant deployments cannot silently starve.
should_verify_with_fallback = (None in known_schemas and None not in active) or (
bool(routine_active) and not active
)
if not should_verify_with_fallback:
return active

fallback_active = await self._scan_active_schemas_by_exists(conn, schemas)
missed = fallback_active - active
if missed:
logger.warning(
"Optional PG routine public.schemas_with_pending_work() missed claimable schema(s) %s; "
"using per-schema fallback for this poll",
sorted(str(s) for s in missed),
)
if has_work:
active.add(schema)
except Exception:
pass
return active
return fallback_active

return await self._scan_active_schemas_by_exists(conn, schemas)

async def _scan_active_schemas_by_exists(
self, conn: "DatabaseConnection", schemas: list[str | None]
) -> set[str | None]:
"""Find active schemas using per-schema EXISTS checks."""
active: set[str | None] = set()
for schema in schemas:
table = fq_table("async_operations", schema)
try:
has_work = await conn.fetchval(
f"SELECT EXISTS(SELECT 1 FROM {table} "
f"WHERE status = 'pending' AND task_payload IS NOT NULL LIMIT 1)"
)
if has_work:
active.add(schema)
except Exception:
pass
return active

async def _get_available_slots(self) -> SlotAvailability:
"""
Expand Down
79 changes: 73 additions & 6 deletions hindsight-api-slim/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,9 @@ async def test_claim_batch_skips_consolidation_when_same_bank_processing(self, p
assert row["worker_id"] is None

@pytest.mark.asyncio
async def test_claim_batch_allows_non_consolidation_when_consolidation_processing(self, pool, backend, clean_operations):
async def test_claim_batch_allows_non_consolidation_when_consolidation_processing(
self, pool, backend, clean_operations
):
"""Test that non-consolidation tasks are still claimed even if consolidation is processing."""
from hindsight_api.worker import WorkerPoller

Expand Down Expand Up @@ -2554,8 +2556,7 @@ async def test_mark_failed_finalises_parent_when_last_sibling_fails(self, pool,
# downstream filters that classify failures by error_message lose all
# signal once a batch has children.
assert "DB constraint violation" in (parent_row["error_message"] or ""), (
f"Parent error_message should inherit child's reason, "
f"got: {parent_row['error_message']!r}"
f"Parent error_message should inherit child's reason, got: {parent_row['error_message']!r}"
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -2586,7 +2587,9 @@ async def test_mark_failed_finalises_parent_when_last_sibling_is_sole_child(self
assert parent_row["status"] == "failed"

@pytest.mark.asyncio
async def test_mark_failed_does_not_finalise_parent_when_siblings_still_pending(self, pool, backend, clean_operations):
async def test_mark_failed_does_not_finalise_parent_when_siblings_still_pending(
self, pool, backend, clean_operations
):
"""Parent is NOT updated while other siblings are still processing/pending."""
from hindsight_api.worker import WorkerPoller

Expand Down Expand Up @@ -2892,7 +2895,8 @@ async def test_scan_uses_optional_routine_when_installed(self, pool, backend, cl
from hindsight_api.worker import WorkerPoller

# Minimal contract-satisfying implementation: returns the empty
# set. Enough to prove the poller follows the server-side path.
# set. Use a non-default schema so the default-schema consistency
# check does not intentionally fall back to the per-schema scan.
await pool.execute(
"CREATE OR REPLACE FUNCTION public.schemas_with_pending_work() "
"RETURNS SETOF text AS $$ BEGIN RETURN; END $$ LANGUAGE plpgsql STABLE"
Expand Down Expand Up @@ -2920,7 +2924,7 @@ async def spy_fetchval(self, query, *args, column=0, timeout=None):
PostgresConnection.fetch = spy_fetch # type: ignore[method-assign]
PostgresConnection.fetchval = spy_fetchval # type: ignore[method-assign]
try:
await poller._scan_active_schemas([None])
await poller._scan_active_schemas(["tenant_alpha"])
finally:
PostgresConnection.fetch = original_fetch # type: ignore[method-assign]
PostgresConnection.fetchval = original_fetchval # type: ignore[method-assign]
Expand All @@ -2941,6 +2945,69 @@ async def spy_fetchval(self, query, *args, column=0, timeout=None):
finally:
await pool.execute("DROP FUNCTION IF EXISTS public.schemas_with_pending_work()")

@pytest.mark.asyncio
async def test_scan_normalizes_public_from_optional_routine(self, pool, backend, clean_operations):
"""The optional routine returns PostgreSQL schema names, while the
poller represents the default schema as None. ``public`` must match
[None] or claim_batch filters the active schema out.
"""
from hindsight_api.worker import WorkerPoller

await pool.execute(
"CREATE OR REPLACE FUNCTION public.schemas_with_pending_work() "
"RETURNS SETOF text AS $$ "
"BEGIN RETURN QUERY SELECT 'public'::text; END "
"$$ LANGUAGE plpgsql STABLE"
)
try:
poller = WorkerPoller(
backend=backend,
worker_id="test-routine-public",
executor=lambda x: None,
)

result = await poller._scan_active_schemas([None])

assert result == {None}
finally:
await pool.execute("DROP FUNCTION IF EXISTS public.schemas_with_pending_work()")

@pytest.mark.asyncio
async def test_scan_falls_back_when_optional_routine_misses_public(self, pool, backend, clean_operations, caplog):
"""If an installed routine scans tenant schemas only, public
single-tenant workers must still use the correct per-schema fallback.
"""
from hindsight_api.worker import WorkerPoller

bank_id = f"test-worker-missed-{uuid.uuid4().hex[:8]}"
await _ensure_bank(pool, bank_id)
await pool.execute(
"""INSERT INTO async_operations
(operation_id, bank_id, operation_type, status, task_payload)
VALUES ($1, $2, 'test', 'pending', $3::jsonb)""",
uuid.uuid4(),
bank_id,
json.dumps({"type": "test", "bank_id": bank_id}),
)
await pool.execute(
"CREATE OR REPLACE FUNCTION public.schemas_with_pending_work() "
"RETURNS SETOF text AS $$ BEGIN RETURN; END $$ LANGUAGE plpgsql STABLE"
)
try:
poller = WorkerPoller(
backend=backend,
worker_id="test-routine-missed-public",
executor=lambda x: None,
)

with caplog.at_level("WARNING", logger="hindsight_api.worker.poller"):
result = await poller._scan_active_schemas([None])

assert None in result
assert any("missed claimable schema" in record.message for record in caplog.records)
finally:
await pool.execute("DROP FUNCTION IF EXISTS public.schemas_with_pending_work()")

@pytest.mark.asyncio
async def test_claim_batch_only_queries_active_schemas(self, pool, backend, clean_operations):
"""claim_batch uses _scan_active_schemas to pre-filter, then
Expand Down