Skip to content

Commit 5d3a520

Browse files
committed
Add backend CI workflow and inference API tests
Introduces a GitHub Actions workflow for backend API CI, running tests on push and pull requests to main. Adds comprehensive tests for inference API endpoints, including health checks, property prediction, attestation-only responses, and verification error handling.
1 parent f821b64 commit 5d3a520

File tree

2 files changed

+358
-0
lines changed

2 files changed

+358
-0
lines changed

.github/workflows/ci-backend.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: Backend API CI
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
env:
10+
CI: "true"
11+
# Force open mode for tests (auth is disabled when API_KEY is empty)
12+
API_KEY: ""
13+
14+
jobs:
15+
backend-tests:
16+
runs-on: ubuntu-latest
17+
18+
steps:
19+
- name: Checkout code
20+
uses: actions/checkout@v4
21+
22+
- name: Set up Python
23+
uses: actions/setup-python@v4
24+
with:
25+
python-version: "3.10"
26+
27+
- name: Install dependencies
28+
run: |
29+
python -m pip install --upgrade pip
30+
pip install -r requirements.txt
31+
pip install pytest pytest-cov
32+
33+
- name: Run backend API tests
34+
run: |
35+
pytest tests/backend --maxfail=1 --disable-warnings
36+
# Optional: enable coverage gating later, e.g.:
37+
# pytest tests/backend --maxfail=1 --disable-warnings \
38+
# --cov=scripts.inference_api --cov-report=term-missing --cov-fail-under=50
Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
import os
2+
from pathlib import Path
3+
from typing import Any, Dict
4+
5+
import numpy as np
6+
import pytest
7+
from fastapi.testclient import TestClient
8+
9+
from scripts import inference_api as api
10+
11+
12+
# ---------------------------------------------------------------------
13+
# Dummy pipeline + registry to avoid loading real models / Algorand
14+
# ---------------------------------------------------------------------
15+
class DummyTree:
16+
"""Simple tree-like estimator used to emulate per-tree predictions."""
17+
18+
def __init__(self, value: float) -> None:
19+
self._value = float(value)
20+
21+
def predict(self, X) -> np.ndarray: # noqa: N803
22+
# Ignore X, always return a single value
23+
return np.array([self._value], dtype=float)
24+
25+
26+
class DummyModel:
27+
"""Simple ensemble-like model with estimators_ attribute."""
28+
29+
def __init__(self) -> None:
30+
# Three slightly different predictions so that std > 0
31+
self.estimators_ = [
32+
DummyTree(190.0),
33+
DummyTree(210.0),
34+
DummyTree(200.0),
35+
]
36+
37+
def predict(self, X) -> np.ndarray: # noqa: N803
38+
return np.array([200.0], dtype=float)
39+
40+
41+
class DummyPipeline:
42+
"""
43+
Minimal pipeline-like object exposing:
44+
- .predict(...)
45+
- .estimator (so that _unwrap_final_estimator can find DummyModel)
46+
"""
47+
48+
def __init__(self) -> None:
49+
self.estimator = DummyModel()
50+
51+
def predict(self, X) -> np.ndarray: # noqa: N803
52+
return self.estimator.predict(X)
53+
54+
55+
class DummyRegistry:
56+
"""Minimal AttestationRegistry replacement for tests."""
57+
58+
def __init__(self) -> None:
59+
self.seen_calls = []
60+
self.record_calls = []
61+
62+
def seen(self, p1_sha: str, asset_id: str) -> bool:
63+
self.seen_calls.append((p1_sha, asset_id))
64+
# Never mark as replay in tests
65+
return False
66+
67+
def record(self, *args: Any, **kwargs: Any) -> None:
68+
self.record_calls.append((args, kwargs))
69+
70+
71+
# ---------------------------------------------------------------------
72+
# Pytest fixtures
73+
# ---------------------------------------------------------------------
74+
@pytest.fixture(autouse=True)
75+
def disable_auth_and_reset_rate_limit(monkeypatch):
76+
"""
77+
Ensure the API runs in "open" mode for tests and reset rate limiter.
78+
"""
79+
# Disable Bearer token requirement
80+
api.API_KEY = None # type: ignore[attr-defined]
81+
82+
# Reset rate limit bucket between tests
83+
if hasattr(api, "_rate_bucket"):
84+
api._rate_bucket.clear() # type: ignore[attr-defined]
85+
86+
yield
87+
88+
if hasattr(api, "_rate_bucket"):
89+
api._rate_bucket.clear() # type: ignore[attr-defined]
90+
91+
92+
@pytest.fixture
93+
def client(monkeypatch, tmp_path) -> TestClient:
94+
"""
95+
TestClient with all heavy dependencies mocked so that:
96+
- no real model files are required
97+
- no Algorand / network calls are performed
98+
- no real files are written outside a temp directory
99+
"""
100+
# ---- Environment / paths isolation ----
101+
monkeypatch.setenv("OUTPUTS_DIR", str(tmp_path))
102+
api.API_LOG_PATH = tmp_path / "api_inference_log.jsonl" # type: ignore[attr-defined]
103+
104+
# ---- Model registry / pipeline mocks ----
105+
dummy_meta: Dict[str, Any] = {
106+
"model_version": "1.0.0-test",
107+
"model_class": "DummyModel",
108+
"metrics": {"MAE": 10.0, "R2": 0.9},
109+
# A small but realistic feature_order
110+
"feature_order": [
111+
"location",
112+
"size_m2",
113+
"rooms",
114+
"bathrooms",
115+
"year_built",
116+
"floor",
117+
"building_floors",
118+
"has_elevator",
119+
"has_garden",
120+
"has_balcony",
121+
"has_garage",
122+
"energy_class",
123+
"humidity_level",
124+
"temperature_avg",
125+
"noise_level",
126+
"air_quality_index",
127+
"age_years",
128+
"listing_month",
129+
"city",
130+
"region",
131+
"zone",
132+
"public_transport_nearby",
133+
],
134+
}
135+
136+
def fake_get_pipeline(asset_type: str, task: str) -> DummyPipeline:
137+
return DummyPipeline()
138+
139+
def fake_get_model_paths(asset_type: str, task: str) -> Dict[str, str]:
140+
return {
141+
"pipeline": "dummy.joblib",
142+
"manifest": "",
143+
}
144+
145+
def fake_get_model_metadata(asset_type: str, task: str) -> Dict[str, Any]:
146+
return dict(dummy_meta)
147+
148+
monkeypatch.setattr(api, "get_pipeline", fake_get_pipeline)
149+
monkeypatch.setattr(api, "get_model_paths", fake_get_model_paths)
150+
monkeypatch.setattr(api, "get_model_metadata", fake_get_model_metadata)
151+
152+
# ---- Validation / explanation / pricing helpers ----
153+
def fake_validate_property(base: Dict[str, Any]) -> Dict[str, Any]:
154+
return {"ok": True, "warnings": [], "errors": []}
155+
156+
def fake_explain_price(base: Dict[str, Any]) -> Dict[str, Any]:
157+
return {"components": []}
158+
159+
def fake_price_benchmark(location: Any, valuation_k: float) -> Dict[str, Any]:
160+
return {
161+
"location": location,
162+
"valuation_k": valuation_k,
163+
"out_of_band": False,
164+
}
165+
166+
monkeypatch.setattr(api, "validate_property", fake_validate_property)
167+
monkeypatch.setattr(api, "explain_price", fake_explain_price)
168+
monkeypatch.setattr(api, "price_benchmark", fake_price_benchmark)
169+
170+
# ---- PoVal builder / canonicalization ----
171+
def fake_build_p1_from_response(response: Dict[str, Any], allowed_input_keys: Any):
172+
p1 = {
173+
"s": "p1",
174+
"v": 200.0,
175+
"u": [180.0, 220.0],
176+
"ts": 1_700_000_000,
177+
}
178+
dbg = {"ih": "fake_input_hash"}
179+
return p1, dbg
180+
181+
def fake_canonical_note_bytes_p1(p1: Dict[str, Any]):
182+
# bytes, sha256, size
183+
return b"{}", "fake_p1_sha256", 128
184+
185+
monkeypatch.setattr(api, "build_p1_from_response", fake_build_p1_from_response)
186+
monkeypatch.setattr(api, "canonical_note_bytes_p1", fake_canonical_note_bytes_p1)
187+
188+
# ---- Audit bundle / registry / network helpers ----
189+
dummy_registry = DummyRegistry()
190+
api.registry = dummy_registry # type: ignore[attr-defined]
191+
192+
def fake_save_audit_bundle(bundle_dir: Path, **kwargs: Any) -> None:
193+
# Do nothing; we just want the call to succeed
194+
return None
195+
196+
def fake_publish_ai_prediction(*args: Any, **kwargs: Any) -> Dict[str, Any]:
197+
return {
198+
"blockchain_txid": "FAKE_TXID",
199+
"asa_id": 12345,
200+
"note_size": 128,
201+
"note_sha256": "fake_p1_sha256",
202+
"is_compacted": True,
203+
"confirmed_round": 42,
204+
}
205+
206+
def fake_get_network() -> str:
207+
return "testnet"
208+
209+
def fake_get_tx_note_info(txid: str) -> Dict[str, Any]:
210+
# Only used when publish=True; here we simulate an indexer response
211+
return {
212+
"note_json": {"s": "p1"},
213+
"note_sha256": "fake_p1_sha256",
214+
"note_size": 128,
215+
"confirmed_round": 42,
216+
"explorer_url": f"https://fake.explorer/tx/{txid}",
217+
}
218+
219+
monkeypatch.setattr(api, "save_audit_bundle", fake_save_audit_bundle)
220+
monkeypatch.setattr(api, "publish_ai_prediction", fake_publish_ai_prediction)
221+
monkeypatch.setattr(api, "get_network", fake_get_network)
222+
monkeypatch.setattr(api, "get_tx_note_info", fake_get_tx_note_info)
223+
224+
# Health / cache helpers (used by /health)
225+
def fake_health_check_model(asset_type: str, task: str) -> Dict[str, Any]:
226+
return {"status": "healthy", "asset_type": asset_type, "task": task}
227+
228+
def fake_cache_stats() -> Dict[str, Any]:
229+
return {"hits": 0, "misses": 0}
230+
231+
monkeypatch.setattr(api, "health_check_model", fake_health_check_model)
232+
monkeypatch.setattr(api, "cache_stats", fake_cache_stats)
233+
234+
return TestClient(api.app)
235+
236+
237+
# ---------------------------------------------------------------------
238+
# Tests
239+
# ---------------------------------------------------------------------
240+
def test_health_endpoint_ok(client: TestClient):
241+
"""Basic sanity check for /health."""
242+
resp = client.get("/health")
243+
assert resp.status_code == 200
244+
data = resp.json()
245+
246+
assert data["status"] in ("ok", "degraded")
247+
assert "model_health" in data
248+
assert "cache_stats" in data
249+
assert "asset_types" in data
250+
assert "schema_version" in data
251+
assert "api_version" in data
252+
253+
254+
def test_predict_property_happy_path(client: TestClient):
255+
"""
256+
/predict/property with minimal payload should return:
257+
- HTTP 200
258+
- valuation metrics
259+
- attestation info (p1 + sha)
260+
- audit bundle id
261+
"""
262+
payload = {
263+
"location": "Milan",
264+
"size_m2": 80,
265+
"rooms": 3,
266+
"bathrooms": 2,
267+
"year_built": 2005,
268+
"floor": 2,
269+
"building_floors": 5,
270+
"has_elevator": 1,
271+
}
272+
273+
resp = client.post("/predict/property?publish=false&attestation_only=false", json=payload)
274+
assert resp.status_code == 200
275+
276+
data = resp.json()
277+
assert data["asset_type"] == "property"
278+
assert "asset_id" in data
279+
assert "metrics" in data
280+
assert "attestation" in data
281+
assert "audit_bundle" in data
282+
283+
metrics = data["metrics"]
284+
assert "valuation_k" in metrics
285+
assert "confidence_low_k" in metrics
286+
assert "confidence_high_k" in metrics
287+
assert metrics["valuation_k"] > 0
288+
289+
290+
def test_predict_property_attestation_only(client: TestClient):
291+
"""
292+
/predict/property with attestation_only=true should return
293+
a compact structure focused on the attestation.
294+
"""
295+
payload = {
296+
"location": "Rome",
297+
"size_m2": 60,
298+
"rooms": 2,
299+
"bathrooms": 1,
300+
"year_built": 2010,
301+
}
302+
303+
resp = client.post("/predict/property?publish=false&attestation_only=true", json=payload)
304+
assert resp.status_code == 200
305+
306+
data = resp.json()
307+
# Compact response: should NOT contain full metrics
308+
assert "asset_id" in data
309+
assert "attestation" in data
310+
assert "attestation_sha256" in data
311+
assert "attestation_size" in data
312+
assert "txid" in data
313+
assert data.get("published") is False
314+
assert "metrics" not in data # ensure we really returned the compact form
315+
316+
317+
def test_verify_missing_txid_returns_422(client: TestClient):
318+
"""POST /verify without 'txid' should return 422."""
319+
resp = client.post("/verify", json={})
320+
assert resp.status_code == 422

0 commit comments

Comments
 (0)