From 9f5081807656cf7128eedba7d59ba2acd4335536 Mon Sep 17 00:00:00 2001 From: NicoleGomes Date: Thu, 11 Jun 2026 12:41:06 +0100 Subject: [PATCH 1/2] feat(backend): add formula-based LiDAR weight prediction baseline --- backend/README.md | 26 +++ backend/api/v1/predictions/__init__.py | 3 + .../api/v1/predictions/prediction_routes.py | 73 +++++++++ .../docs/weight_estimation_endpoint_notes.md | 31 ++++ backend/main.py | 2 + backend/prediction/__init__.py | 18 +++ backend/prediction/confidence.py | 56 +++++++ backend/prediction/feature_builder.py | 33 ++++ backend/prediction/formula_baseline.py | 28 ++++ backend/prediction/measurement_validator.py | 148 ++++++++++++++++++ backend/prediction/model_registry.py | 45 ++++++ backend/prediction/predictor.py | 83 ++++++++++ backend/prediction/schemas.py | 39 +++++ backend/tests/test_weight_estimation_agent.py | 108 +++++++++++++ backend/tests/test_weight_estimation_api.py | 105 +++++++++++++ 15 files changed, 798 insertions(+) create mode 100644 backend/api/v1/predictions/__init__.py create mode 100644 backend/api/v1/predictions/prediction_routes.py create mode 100644 backend/docs/weight_estimation_endpoint_notes.md create mode 100644 backend/prediction/__init__.py create mode 100644 backend/prediction/confidence.py create mode 100644 backend/prediction/feature_builder.py create mode 100644 backend/prediction/formula_baseline.py create mode 100644 backend/prediction/measurement_validator.py create mode 100644 backend/prediction/model_registry.py create mode 100644 backend/prediction/predictor.py create mode 100644 backend/prediction/schemas.py create mode 100644 backend/tests/test_weight_estimation_agent.py create mode 100644 backend/tests/test_weight_estimation_api.py diff --git a/backend/README.md b/backend/README.md index c1e08db..ade9d98 100644 --- a/backend/README.md +++ b/backend/README.md @@ -121,6 +121,32 @@ Multipart form upload: | `animal_id` | string | no | `DEMO-001` | | `breed` | string | no | `default` | +### Predictions + +- `POST /api/v1/predictions/weight-estimation` + +JSON request example: + +```bash +curl -X POST http://localhost:8000/api/v1/predictions/weight-estimation \ + -H "Content-Type: application/json" \ + -d '{ + "species": "cattle", + "breed": "minhota", + "sex": "female", + "age_months": 28, + "measurements": { + "body_length_cm": 152.4, + "withers_height_cm": 126.8, + "thoracic_depth_cm": 65.9, + "rump_width_cm": 50.2, + "chest_girth_cm": 194.3 + } + }' +``` + +This endpoint does not persist any data. It validates morphometric measurements and returns a formula-based approximate weight estimate with conservative diagnostics. + ### Organizations - `POST /api/v1/organizations` diff --git a/backend/api/v1/predictions/__init__.py b/backend/api/v1/predictions/__init__.py new file mode 100644 index 0000000..80cbbb3 --- /dev/null +++ b/backend/api/v1/predictions/__init__.py @@ -0,0 +1,3 @@ +from api.v1.predictions.prediction_routes import predictions_router + +__all__ = ["predictions_router"] diff --git a/backend/api/v1/predictions/prediction_routes.py b/backend/api/v1/predictions/prediction_routes.py new file mode 100644 index 0000000..bfb4d28 --- /dev/null +++ b/backend/api/v1/predictions/prediction_routes.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from fastapi import APIRouter, Body, HTTPException, status +from fastapi.exceptions import RequestValidationError +from fastapi.routing import APIRoute + +from prediction.model_registry import get_model_registry +from prediction.schemas import WeightEstimationRequest, WeightEstimationResponse + + +class UnprocessableEntityValidationRoute(APIRoute): + def get_route_handler(self): + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request): + try: + return await original_route_handler(request) + except RequestValidationError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=exc.errors(), + ) from exc + + return custom_route_handler + + +predictions_router = APIRouter( + prefix="/api/v1/predictions", + tags=["predictions"], + route_class=UnprocessableEntityValidationRoute, +) + +weight_estimation_predictor = get_model_registry().get_default() + +WEIGHT_ESTIMATION_REQUEST_EXAMPLE = { + "species": "cattle", + "breed": "minhota", + "sex": "female", + "age_months": 28, + "measurements": { + "body_length_cm": 152.4, + "withers_height_cm": 126.8, + "thoracic_depth_cm": 65.9, + "rump_width_cm": 50.2, + "chest_girth_cm": 194.3, + }, +} + + +@predictions_router.post( + "/weight-estimation", + response_model=WeightEstimationResponse, + status_code=status.HTTP_200_OK, + summary="Estimate livestock weight from morphometric measurements", +) +def estimate_weight( + payload: WeightEstimationRequest = Body( + ..., + openapi_examples={ + "baseline_formula_request": { + "summary": "Formula-based weight estimation request", + "value": WEIGHT_ESTIMATION_REQUEST_EXAMPLE, + }, + }, + ), +) -> WeightEstimationResponse: + try: + return weight_estimation_predictor.predict(payload) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=str(exc), + ) from exc diff --git a/backend/docs/weight_estimation_endpoint_notes.md b/backend/docs/weight_estimation_endpoint_notes.md new file mode 100644 index 0000000..467cadb --- /dev/null +++ b/backend/docs/weight_estimation_endpoint_notes.md @@ -0,0 +1,31 @@ +# Weight Estimation Endpoint Notes + +This endpoint adds a direct way to test formula-based livestock weight estimation without storing any data in the database. + +## What was added + +- A new route at `POST /api/v1/predictions/weight-estimation` +- Integration tests using FastAPI `TestClient` +- Swagger/OpenAPI request example for manual testing + +## Why this endpoint exists + +This route is useful as an isolated validation step before connecting weight estimation to persisted scan workflows. It lets the team verify request validation, response shape, diagnostics, and predictor integration without coupling the work to image upload or database persistence. + +## Request and response flow + +1. FastAPI receives the HTTP request body. +2. The `WeightEstimationRequest` Pydantic schema validates and parses the payload. +3. The route calls the existing formula-based predictor from `backend/prediction/`. +4. The predictor returns a `WeightEstimationResponse` with an estimated weight, conservative confidence score, and diagnostics. +5. FastAPI serializes the response into JSON and exposes it in Swagger. + +## Validation behavior + +- Valid payloads return `200 OK`. +- Invalid payloads for this route return `422 Unprocessable Entity`. +- The route uses a route-specific validation handler so the new endpoint can follow standard FastAPI `422` behavior without changing the validation behavior of unrelated legacy endpoints. + +## Important implementation detail + +The repository uses a shared Pydantic base model with camelCase aliases for JSON serialization. The prediction module keeps snake_case field names in Python code, while the API can still accept the snake_case example payload and serialize responses consistently with the existing project conventions. diff --git a/backend/main.py b/backend/main.py index fe852a1..0d4dede 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6,6 +6,7 @@ from api.v1.organizations.organizations_routes import organizations_router from api.v1.users.users_routes import users_router from api.v1.organizations_members.organizations_members_routes import organizations_member_router +from api.v1.predictions.prediction_routes import predictions_router from core.database import initialize_database from core.errors import register_exception_handlers @@ -99,6 +100,7 @@ async def scan( app.include_router(organizations_router) app.include_router(users_router) app.include_router(organizations_member_router) + app.include_router(predictions_router) return app diff --git a/backend/prediction/__init__.py b/backend/prediction/__init__.py new file mode 100644 index 0000000..434959f --- /dev/null +++ b/backend/prediction/__init__.py @@ -0,0 +1,18 @@ +from prediction.model_registry import DEFAULT_MODEL_VERSION, get_model_registry +from prediction.predictor import FormulaBasedWeightPredictor +from prediction.schemas import ( + AnimalMeasurements, + PredictionDiagnostics, + WeightEstimationRequest, + WeightEstimationResponse, +) + +__all__ = [ + "AnimalMeasurements", + "DEFAULT_MODEL_VERSION", + "FormulaBasedWeightPredictor", + "PredictionDiagnostics", + "WeightEstimationRequest", + "WeightEstimationResponse", + "get_model_registry", +] diff --git a/backend/prediction/confidence.py b/backend/prediction/confidence.py new file mode 100644 index 0000000..34c4de1 --- /dev/null +++ b/backend/prediction/confidence.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from prediction.feature_builder import FormulaFeatures +from prediction.measurement_validator import MeasurementValidationResult + + +def calculate_confidence_score( + *, + validation_result: MeasurementValidationResult, + features: FormulaFeatures, +) -> float: + average_boundary_centrality = sum( + validation_result.boundary_centrality.values(), + ) / len(validation_result.boundary_centrality) + + ratio_alignment = _average( + [ + _ratio_alignment( + value=features.chest_girth_to_length_ratio, + target=1.28, + tolerance=0.35, + ), + _ratio_alignment( + value=features.withers_height_to_depth_ratio, + target=1.95, + tolerance=0.55, + ), + _ratio_alignment( + value=features.rump_width_to_girth_ratio, + target=0.29, + tolerance=0.12, + ), + ], + ) + + warning_penalty = min(0.20, 0.04 * len(validation_result.warnings)) + score = ( + 0.55 + + 0.08 * (average_boundary_centrality - 0.5) + + 0.06 * (ratio_alignment - 0.5) + - warning_penalty + ) + + bounded_score = max(0.20, min(0.70, score)) + return round(bounded_score, 2) + + +def _ratio_alignment(*, value: float, target: float, tolerance: float) -> float: + deviation = abs(value - target) + if deviation >= tolerance: + return 0.0 + return 1.0 - (deviation / tolerance) + + +def _average(values: list[float]) -> float: + return sum(values) / len(values) diff --git a/backend/prediction/feature_builder.py b/backend/prediction/feature_builder.py new file mode 100644 index 0000000..ee3cbc2 --- /dev/null +++ b/backend/prediction/feature_builder.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from prediction.schemas import AnimalMeasurements + + +CENTIMETERS_PER_INCH = 2.54 + + +@dataclass(frozen=True) +class FormulaFeatures: + heart_girth_in: float + body_length_in: float + chest_girth_to_length_ratio: float + withers_height_to_depth_ratio: float + rump_width_to_girth_ratio: float + + +def build_formula_features(measurements: AnimalMeasurements) -> FormulaFeatures: + return FormulaFeatures( + heart_girth_in=measurements.chest_girth_cm / CENTIMETERS_PER_INCH, + body_length_in=measurements.body_length_cm / CENTIMETERS_PER_INCH, + chest_girth_to_length_ratio=( + measurements.chest_girth_cm / measurements.body_length_cm + ), + withers_height_to_depth_ratio=( + measurements.withers_height_cm / measurements.thoracic_depth_cm + ), + rump_width_to_girth_ratio=( + measurements.rump_width_cm / measurements.chest_girth_cm + ), + ) diff --git a/backend/prediction/formula_baseline.py b/backend/prediction/formula_baseline.py new file mode 100644 index 0000000..2e52b03 --- /dev/null +++ b/backend/prediction/formula_baseline.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from prediction.feature_builder import FormulaFeatures + + +MODEL_VERSION = "formula-baseline-v0.1.0" +ESTIMATION_METHOD = "heart_girth_body_length_formula" +POUNDS_TO_KILOGRAMS = 0.45359237 + + +@dataclass(frozen=True) +class FormulaEstimation: + estimated_weight_lb: float + estimated_weight_kg: float + + +def estimate_weight_from_formula(features: FormulaFeatures) -> FormulaEstimation: + estimated_weight_lb = ( + (features.heart_girth_in ** 2) * features.body_length_in + ) / 300 + estimated_weight_kg = estimated_weight_lb * POUNDS_TO_KILOGRAMS + + return FormulaEstimation( + estimated_weight_lb=estimated_weight_lb, + estimated_weight_kg=estimated_weight_kg, + ) diff --git a/backend/prediction/measurement_validator.py b/backend/prediction/measurement_validator.py new file mode 100644 index 0000000..0084236 --- /dev/null +++ b/backend/prediction/measurement_validator.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from prediction.schemas import AnimalMeasurements + + +@dataclass(frozen=True) +class PlausibleRange: + minimum: float + maximum: float + warning_margin: float + + +@dataclass(frozen=True) +class MeasurementValidationResult: + measurements: AnimalMeasurements + warnings: list[str] + input_quality: str + boundary_centrality: dict[str, float] + + +PLAUSIBLE_BOVINE_RANGES: dict[str, PlausibleRange] = { + "body_length_cm": PlausibleRange(minimum=80.0, maximum=260.0, warning_margin=10.0), + "withers_height_cm": PlausibleRange( + minimum=80.0, + maximum=210.0, + warning_margin=8.0, + ), + "thoracic_depth_cm": PlausibleRange( + minimum=35.0, + maximum=120.0, + warning_margin=6.0, + ), + "rump_width_cm": PlausibleRange(minimum=20.0, maximum=100.0, warning_margin=5.0), + "chest_girth_cm": PlausibleRange( + minimum=90.0, + maximum=320.0, + warning_margin=12.0, + ), +} + + +def validate_measurements( + measurements: AnimalMeasurements, +) -> MeasurementValidationResult: + warnings: list[str] = [] + boundary_centrality: dict[str, float] = {} + + for field_name, plausible_range in PLAUSIBLE_BOVINE_RANGES.items(): + value = float(getattr(measurements, field_name)) + _validate_positive_measurement(field_name=field_name, value=value) + _validate_plausible_range( + field_name=field_name, + value=value, + plausible_range=plausible_range, + ) + if _is_near_boundary(value=value, plausible_range=plausible_range): + warnings.append( + f"{field_name} is close to the plausible bovine range boundary.", + ) + boundary_centrality[field_name] = _calculate_boundary_centrality( + value=value, + plausible_range=plausible_range, + ) + + warnings.extend(_build_consistency_warnings(measurements)) + input_quality = "valid_with_warnings" if warnings else "valid" + + return MeasurementValidationResult( + measurements=measurements, + warnings=warnings, + input_quality=input_quality, + boundary_centrality=boundary_centrality, + ) + + +def _validate_positive_measurement(*, field_name: str, value: float) -> None: + if value == 0: + raise ValueError(f"{field_name} must be greater than zero.") + if value < 0: + raise ValueError(f"{field_name} must be positive.") + + +def _validate_plausible_range( + *, + field_name: str, + value: float, + plausible_range: PlausibleRange, +) -> None: + if value < plausible_range.minimum or value > plausible_range.maximum: + raise ValueError( + f"{field_name}={value} is outside the plausible bovine range " + f"{plausible_range.minimum}-{plausible_range.maximum} cm.", + ) + + +def _is_near_boundary(*, value: float, plausible_range: PlausibleRange) -> bool: + return ( + value - plausible_range.minimum <= plausible_range.warning_margin + or plausible_range.maximum - value <= plausible_range.warning_margin + ) + + +def _calculate_boundary_centrality( + *, + value: float, + plausible_range: PlausibleRange, +) -> float: + half_span = (plausible_range.maximum - plausible_range.minimum) / 2 + midpoint = plausible_range.minimum + half_span + distance_from_center = abs(value - midpoint) + normalized_distance = min(distance_from_center / half_span, 1.0) + return 1.0 - normalized_distance + + +def _build_consistency_warnings(measurements: AnimalMeasurements) -> list[str]: + warnings: list[str] = [] + + chest_girth_to_length_ratio = ( + measurements.chest_girth_cm / measurements.body_length_cm + ) + if chest_girth_to_length_ratio < 0.85 or chest_girth_to_length_ratio > 1.70: + warnings.append( + "The chest girth to body length ratio " + "is unusual for bovine body proportions.", + ) + + height_to_depth_ratio = ( + measurements.withers_height_cm / measurements.thoracic_depth_cm + ) + if height_to_depth_ratio < 1.25 or height_to_depth_ratio > 3.20: + warnings.append( + "The withers height to thoracic depth ratio looks inconsistent.", + ) + + rump_width_to_girth_ratio = measurements.rump_width_cm / measurements.chest_girth_cm + if rump_width_to_girth_ratio < 0.15 or rump_width_to_girth_ratio > 0.45: + warnings.append( + "The rump width is not well aligned with the reported chest girth.", + ) + + if measurements.thoracic_depth_cm >= measurements.withers_height_cm: + warnings.append( + "Thoracic depth should normally remain below withers height.", + ) + + return warnings diff --git a/backend/prediction/model_registry.py b/backend/prediction/model_registry.py new file mode 100644 index 0000000..b9427f7 --- /dev/null +++ b/backend/prediction/model_registry.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import Protocol + +from prediction.formula_baseline import MODEL_VERSION +from prediction.predictor import FormulaBasedWeightPredictor +from prediction.schemas import WeightEstimationRequest, WeightEstimationResponse + + +DEFAULT_MODEL_VERSION = MODEL_VERSION + + +class WeightPredictionModel(Protocol): + model_version: str + + def predict( + self, + request: WeightEstimationRequest, + ) -> WeightEstimationResponse: ... + + +class ModelRegistry: + def __init__(self) -> None: + self._models: dict[str, WeightPredictionModel] = {} + + def register(self, model: WeightPredictionModel) -> None: + self._models[model.model_version] = model + + def get(self, model_version: str) -> WeightPredictionModel: + if model_version not in self._models: + available_versions = ", ".join(sorted(self._models)) + raise KeyError( + f"Unknown model version '{model_version}'. " + f"Available versions: {available_versions}", + ) + return self._models[model_version] + + def get_default(self) -> WeightPredictionModel: + return self.get(DEFAULT_MODEL_VERSION) + + +def get_model_registry() -> ModelRegistry: + registry = ModelRegistry() + registry.register(FormulaBasedWeightPredictor()) + return registry diff --git a/backend/prediction/predictor.py b/backend/prediction/predictor.py new file mode 100644 index 0000000..6eca0bf --- /dev/null +++ b/backend/prediction/predictor.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from prediction.confidence import calculate_confidence_score +from prediction.feature_builder import build_formula_features +from prediction.formula_baseline import ( + ESTIMATION_METHOD, + MODEL_VERSION, + estimate_weight_from_formula, +) +from prediction.measurement_validator import validate_measurements +from prediction.schemas import ( + PredictionDiagnostics, + WeightEstimationRequest, + WeightEstimationResponse, +) + + +class FormulaBasedWeightPredictor: + model_version = MODEL_VERSION + estimation_method = ESTIMATION_METHOD + + def predict( + self, + request: WeightEstimationRequest, + ) -> WeightEstimationResponse: + validation_result = validate_measurements(request.measurements) + features = build_formula_features(validation_result.measurements) + estimation = estimate_weight_from_formula(features) + confidence_score = calculate_confidence_score( + validation_result=validation_result, + features=features, + ) + + warnings = list(validation_result.warnings) + warnings.extend(_build_context_warnings(request)) + warnings.append( + "Estimation is based on a morphometric formula " + "and has not been calibrated with ground-truth farm data.", + ) + + diagnostics = PredictionDiagnostics( + input_quality=validation_result.input_quality, + warnings=warnings, + requires_ground_truth_validation=True, + is_formula_based=True, + is_trained_model=False, + ) + + return WeightEstimationResponse( + estimated_weight_kg=round(estimation.estimated_weight_kg, 1), + confidence_score=confidence_score, + model_version=self.model_version, + estimation_method=self.estimation_method, + diagnostics=diagnostics, + ) + + +def _build_context_warnings(request: WeightEstimationRequest) -> list[str]: + warnings: list[str] = [] + normalized_species = request.species.strip().lower() + bovine_labels = { + "bovine", + "cattle", + "cow", + "bull", + "heifer", + "calf", + "bos taurus", + "bos indicus", + } + if normalized_species not in bovine_labels: + warnings.append( + "Plausibility checks are tuned for bovines " + "and may not generalize to other species.", + ) + + if request.age_months < 6: + warnings.append( + "Very young animals can deviate materially " + "from adult morphometric formulas.", + ) + + return warnings diff --git a/backend/prediction/schemas.py b/backend/prediction/schemas.py new file mode 100644 index 0000000..6770db4 --- /dev/null +++ b/backend/prediction/schemas.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Literal + +from pydantic import Field + +from schemas.base import APIModel + + +class AnimalMeasurements(APIModel): + body_length_cm: float = Field(gt=0) + withers_height_cm: float = Field(gt=0) + thoracic_depth_cm: float = Field(gt=0) + rump_width_cm: float = Field(gt=0) + chest_girth_cm: float = Field(gt=0) + + +class WeightEstimationRequest(APIModel): + species: str = Field(min_length=1) + breed: str = Field(min_length=1) + sex: str = Field(min_length=1) + age_months: int = Field(ge=0) + measurements: AnimalMeasurements + + +class PredictionDiagnostics(APIModel): + input_quality: Literal["valid", "valid_with_warnings"] + warnings: list[str] + requires_ground_truth_validation: bool + is_formula_based: bool + is_trained_model: bool + + +class WeightEstimationResponse(APIModel): + estimated_weight_kg: float + confidence_score: float + model_version: str + estimation_method: str + diagnostics: PredictionDiagnostics diff --git a/backend/tests/test_weight_estimation_agent.py b/backend/tests/test_weight_estimation_agent.py new file mode 100644 index 0000000..1e170e1 --- /dev/null +++ b/backend/tests/test_weight_estimation_agent.py @@ -0,0 +1,108 @@ +import unittest + +from pydantic import ValidationError + +from prediction.model_registry import DEFAULT_MODEL_VERSION, get_model_registry +from prediction.schemas import WeightEstimationRequest + + +class WeightEstimationAgentTests(unittest.TestCase): + def setUp(self): + self.predictor = get_model_registry().get_default() + + def build_request(self, **measurement_overrides) -> WeightEstimationRequest: + measurements = { + "body_length_cm": 150.0, + "withers_height_cm": 135.0, + "thoracic_depth_cm": 72.0, + "rump_width_cm": 52.0, + "chest_girth_cm": 188.0, + } + measurements.update(measurement_overrides) + return WeightEstimationRequest( + species="cattle", + breed="angus", + sex="female", + age_months=24, + measurements=measurements, + ) + + def test_valid_payload_returns_positive_estimated_weight(self): + response = self.predictor.predict(self.build_request()) + + self.assertGreater(response.estimated_weight_kg, 0) + self.assertEqual(response.model_version, DEFAULT_MODEL_VERSION) + self.assertEqual( + response.estimation_method, + "heart_girth_body_length_formula", + ) + + def test_negative_measurement_raises_validation_error(self): + with self.assertRaises(ValidationError): + self.build_request(body_length_cm=-1) + + def test_zero_measurement_raises_validation_error(self): + with self.assertRaises(ValidationError): + self.build_request(chest_girth_cm=0) + + def test_out_of_plausible_range_raises_error(self): + request = self.build_request(chest_girth_cm=340.0) + + with self.assertRaises(ValueError): + self.predictor.predict(request) + + def test_near_plausible_range_boundary_generates_warning(self): + response = self.predictor.predict(self.build_request(body_length_cm=82.0)) + + self.assertEqual(response.diagnostics.input_quality, "valid_with_warnings") + self.assertTrue( + any( + "range boundary" in warning + for warning in response.diagnostics.warnings + ), + ) + + def test_result_is_deterministic_for_same_input(self): + request = self.build_request() + + first_response = self.predictor.predict(request) + second_response = self.predictor.predict(request) + + self.assertEqual( + first_response.estimated_weight_kg, + second_response.estimated_weight_kg, + ) + self.assertEqual( + first_response.confidence_score, + second_response.confidence_score, + ) + self.assertEqual( + first_response.diagnostics.warnings, + second_response.diagnostics.warnings, + ) + + def test_confidence_score_stays_between_zero_and_one(self): + response = self.predictor.predict(self.build_request()) + + self.assertGreaterEqual(response.confidence_score, 0.0) + self.assertLessEqual(response.confidence_score, 1.0) + + def test_confidence_score_does_not_exceed_version_cap(self): + response = self.predictor.predict(self.build_request()) + + self.assertLessEqual(response.confidence_score, 0.70) + + def test_diagnostics_require_ground_truth_validation(self): + response = self.predictor.predict(self.build_request()) + + self.assertTrue(response.diagnostics.requires_ground_truth_validation) + + def test_diagnostics_identify_formula_based_estimation(self): + response = self.predictor.predict(self.build_request()) + + self.assertTrue(response.diagnostics.is_formula_based) + self.assertFalse(response.diagnostics.is_trained_model) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_weight_estimation_api.py b/backend/tests/test_weight_estimation_api.py new file mode 100644 index 0000000..a863826 --- /dev/null +++ b/backend/tests/test_weight_estimation_api.py @@ -0,0 +1,105 @@ +import os +import unittest + +os.environ["DATABASE_URL"] = "sqlite://" + +from fastapi.testclient import TestClient + +from main import app +from prediction.schemas import WeightEstimationResponse + + +class WeightEstimationApiTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.client = TestClient(app) + + def build_payload(self, **measurement_overrides): + measurements = { + "body_length_cm": 152.4, + "withers_height_cm": 126.8, + "thoracic_depth_cm": 65.9, + "rump_width_cm": 50.2, + "chest_girth_cm": 194.3, + } + measurements.update(measurement_overrides) + return { + "species": "cattle", + "breed": "minhota", + "sex": "female", + "age_months": 28, + "measurements": measurements, + } + + def test_valid_payload_returns_http_200(self): + response = self.client.post( + "/api/v1/predictions/weight-estimation", + json=self.build_payload(), + ) + + self.assertEqual(response.status_code, 200) + response_json = response.json() + parsed_response = WeightEstimationResponse.model_validate(response_json) + + self.assertIn("estimatedWeightKg", response_json) + self.assertIn("confidenceScore", response_json) + self.assertGreater(parsed_response.estimated_weight_kg, 0) + self.assertLessEqual(parsed_response.confidence_score, 0.70) + self.assertEqual( + parsed_response.model_version, + "formula-baseline-v0.1.0", + ) + self.assertTrue(parsed_response.diagnostics.requires_ground_truth_validation) + self.assertTrue(parsed_response.diagnostics.is_formula_based) + self.assertFalse(parsed_response.diagnostics.is_trained_model) + + def test_negative_measurement_returns_http_422(self): + response = self.client.post( + "/api/v1/predictions/weight-estimation", + json=self.build_payload(chest_girth_cm=-194.3), + ) + + self.assertEqual(response.status_code, 422) + + def test_zero_measurement_returns_http_422(self): + response = self.client.post( + "/api/v1/predictions/weight-estimation", + json=self.build_payload(body_length_cm=0), + ) + + self.assertEqual(response.status_code, 422) + + def test_missing_chest_girth_returns_http_422(self): + payload = self.build_payload() + payload["measurements"].pop("chest_girth_cm") + + response = self.client.post( + "/api/v1/predictions/weight-estimation", + json=payload, + ) + + self.assertEqual(response.status_code, 422) + + def test_missing_body_length_returns_http_422(self): + payload = self.build_payload() + payload["measurements"].pop("body_length_cm") + + response = self.client.post( + "/api/v1/predictions/weight-estimation", + json=payload, + ) + + self.assertEqual(response.status_code, 422) + + def test_endpoint_is_present_in_openapi_schema(self): + response = self.client.get("/openapi.json") + + self.assertEqual(response.status_code, 200) + self.assertIn( + "/api/v1/predictions/weight-estimation", + response.json()["paths"], + ) + + +if __name__ == "__main__": + unittest.main() From d2d0e819be9a8117d4b98b4af636915537565ab3 Mon Sep 17 00:00:00 2001 From: NicoleGomes Date: Thu, 11 Jun 2026 16:41:39 +0100 Subject: [PATCH 2/2] EURODEV-84: fix/fix-import-smoke-test-ruff-format --- .../organizations_members_routes.py | 1 + backend/core/config.py | 5 +- backend/core/database.py | 28 +++++++-- backend/main.py | 12 ++-- backend/prediction/formula_baseline.py | 4 +- backend/repositories/user_repository.py | 4 +- backend/schemas/organization_schemas.py | 4 +- .../services/organization_member_service.py | 10 +-- backend/tests/test_import_smoke.py | 63 +++++++++++++++++++ backend/tests/test_weight_estimation_agent.py | 3 +- 10 files changed, 110 insertions(+), 24 deletions(-) create mode 100644 backend/tests/test_import_smoke.py diff --git a/backend/api/v1/organizations_members/organizations_members_routes.py b/backend/api/v1/organizations_members/organizations_members_routes.py index 5aebf49..5d37245 100644 --- a/backend/api/v1/organizations_members/organizations_members_routes.py +++ b/backend/api/v1/organizations_members/organizations_members_routes.py @@ -17,6 +17,7 @@ tags=["members"], ) + @organizations_member_router.get( "/{organization_id}/members", response_model=list[OrganizationMemberResponse], diff --git a/backend/core/config.py b/backend/core/config.py index f35354d..5bd7e6e 100644 --- a/backend/core/config.py +++ b/backend/core/config.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass +from functools import lru_cache from pathlib import Path from urllib.parse import quote_plus @@ -74,4 +75,6 @@ def _build_database_url() -> str: ).render_as_string(hide_password=False) -settings = Settings(database_url=_build_database_url()) +@lru_cache(maxsize=1) +def get_settings() -> Settings: + return Settings(database_url=_build_database_url()) diff --git a/backend/core/database.py b/backend/core/database.py index ac098c1..e2e50c9 100644 --- a/backend/core/database.py +++ b/backend/core/database.py @@ -6,12 +6,15 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import Session, declarative_base, sessionmaker -from core.config import settings +from core.config import get_settings Base = declarative_base() +_engine: Engine | None = None +_session_local: sessionmaker | None = None def _create_engine() -> Engine: + settings = get_settings() connect_args = {} if settings.database_url.startswith("sqlite"): connect_args["check_same_thread"] = False @@ -34,12 +37,27 @@ def _enable_sqlite_foreign_keys(dbapi_connection, connection_record): cursor.close() -engine = _create_engine() -SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True) +def get_engine() -> Engine: + global _engine + if _engine is None: + _engine = _create_engine() + return _engine + + +def get_session_local() -> sessionmaker: + global _session_local + if _session_local is None: + _session_local = sessionmaker( + bind=get_engine(), + autoflush=False, + autocommit=False, + future=True, + ) + return _session_local def get_db() -> Generator[Session, None, None]: - db = SessionLocal() + db = get_session_local()() try: yield db finally: @@ -48,3 +66,5 @@ def get_db() -> Generator[Session, None, None]: def initialize_database() -> None: import models.models # noqa: F401 + + get_engine() diff --git a/backend/main.py b/backend/main.py index 0d4dede..a70aa46 100644 --- a/backend/main.py +++ b/backend/main.py @@ -5,7 +5,9 @@ from api.v1.organizations.organizations_routes import organizations_router from api.v1.users.users_routes import users_router -from api.v1.organizations_members.organizations_members_routes import organizations_member_router +from api.v1.organizations_members.organizations_members_routes import ( + organizations_member_router, +) from api.v1.predictions.prediction_routes import predictions_router from core.database import initialize_database from core.errors import register_exception_handlers @@ -31,12 +33,10 @@ def startup_event(): def health(): return {"status": "ok", "service": "PondiFarm API v0.1 — Phase 0"} - @app.get("/health") def health_check(): return {"status": "ok"} - @app.post("/api/v1/scan") async def scan( file: UploadFile = File(...), @@ -91,9 +91,9 @@ async def scan( "chest_girth_cm": measurements["chest_girth_cm"], }, "result": { - "estimated_weight_kg": peso_kg, - "confidence_pct": confianca, - "accuracy_note": "Estimativa por visão computacional 2D · Precisão aumentada com LiDAR na versão final", + "estimated_weight_kg": peso_kg, + "confidence_pct": confianca, + "accuracy_note": "Estimativa por visão computacional 2D · Precisão aumentada com LiDAR na versão final", }, } diff --git a/backend/prediction/formula_baseline.py b/backend/prediction/formula_baseline.py index 2e52b03..7ede64d 100644 --- a/backend/prediction/formula_baseline.py +++ b/backend/prediction/formula_baseline.py @@ -17,9 +17,7 @@ class FormulaEstimation: def estimate_weight_from_formula(features: FormulaFeatures) -> FormulaEstimation: - estimated_weight_lb = ( - (features.heart_girth_in ** 2) * features.body_length_in - ) / 300 + estimated_weight_lb = ((features.heart_girth_in**2) * features.body_length_in) / 300 estimated_weight_kg = estimated_weight_lb * POUNDS_TO_KILOGRAMS return FormulaEstimation( diff --git a/backend/repositories/user_repository.py b/backend/repositories/user_repository.py index aac6994..d3c80ee 100644 --- a/backend/repositories/user_repository.py +++ b/backend/repositories/user_repository.py @@ -18,9 +18,7 @@ def create_user(db: Session, user: User) -> User: def list_active_users(db: Session) -> list[User]: statement = ( - select(User) - .where(User.deleted_at.is_(None)) - .order_by(User.created_at.asc()) + select(User).where(User.deleted_at.is_(None)).order_by(User.created_at.asc()) ) return list(db.scalars(statement).all()) diff --git a/backend/schemas/organization_schemas.py b/backend/schemas/organization_schemas.py index 802ee3a..4f36eb3 100644 --- a/backend/schemas/organization_schemas.py +++ b/backend/schemas/organization_schemas.py @@ -23,7 +23,9 @@ def validate_portuguese_nif(value: str) -> str: if digits is None or len(digits) != 9 or digits[0] not in "1235689": raise ValueError("documentNumber must be a valid Portuguese NIF") - total = sum(int(digit) * weight for digit, weight in zip(digits[:8], range(9, 1, -1))) + total = sum( + int(digit) * weight for digit, weight in zip(digits[:8], range(9, 1, -1)) + ) remainder = total % 11 check_digit = 0 if remainder < 2 else 11 - remainder diff --git a/backend/services/organization_member_service.py b/backend/services/organization_member_service.py index c36bc85..0a33ffe 100644 --- a/backend/services/organization_member_service.py +++ b/backend/services/organization_member_service.py @@ -60,10 +60,12 @@ def create_member( get_organization_entity(db, organization_id) get_user_entity(db, payload.user_id) - existing = organization_member_repository.get_active_member_by_organization_and_user( - db, - organization_id, - payload.user_id, + existing = ( + organization_member_repository.get_active_member_by_organization_and_user( + db, + organization_id, + payload.user_id, + ) ) if existing: raise HTTPException( diff --git a/backend/tests/test_import_smoke.py b/backend/tests/test_import_smoke.py new file mode 100644 index 0000000..f4c5466 --- /dev/null +++ b/backend/tests/test_import_smoke.py @@ -0,0 +1,63 @@ +import subprocess +import textwrap +import unittest +from pathlib import Path + + +class MainImportSmokeTests(unittest.TestCase): + def test_import_main_does_not_require_database_env(self): + backend_dir = Path(__file__).resolve().parent.parent + python_executable = backend_dir / ".venv" / "Scripts" / "python.exe" + + script = textwrap.dedent( + """ + import os + import pathlib + + for key in ( + "DATABASE_URL", + "AZURE_SQL_SERVER", + "AZURE_SQL_DATABASE", + "AZURE_SQL_USERNAME", + "AZURE_SQL_PASSWORD", + "AZURE_SQL_DRIVER", + "AZURE_SQL_ENCRYPT", + "AZURE_SQL_TRUST_SERVER_CERTIFICATE", + ): + os.environ.pop(key, None) + + original_exists = pathlib.Path.exists + + def patched_exists(path): + if path.name == ".env" and path.parent.name == "backend": + return False + return original_exists(path) + + pathlib.Path.exists = patched_exists + + import main + + print("import-ok") + """, + ) + + completed = subprocess.run( + [str(python_executable), "-c", script], + capture_output=True, + text=True, + cwd=backend_dir, + check=False, + ) + + if completed.returncode != 0: + self.fail( + "Import smoke test failed.\n" + f"STDOUT:\n{completed.stdout}\n" + f"STDERR:\n{completed.stderr}", + ) + + self.assertIn("import-ok", completed.stdout) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_weight_estimation_agent.py b/backend/tests/test_weight_estimation_agent.py index 1e170e1..b22e33c 100644 --- a/backend/tests/test_weight_estimation_agent.py +++ b/backend/tests/test_weight_estimation_agent.py @@ -57,8 +57,7 @@ def test_near_plausible_range_boundary_generates_warning(self): self.assertEqual(response.diagnostics.input_quality, "valid_with_warnings") self.assertTrue( any( - "range boundary" in warning - for warning in response.diagnostics.warnings + "range boundary" in warning for warning in response.diagnostics.warnings ), )