diff --git a/backend/services/ocr_service.py b/backend/services/ocr_service.py index 1887a7674..ca61728ec 100644 --- a/backend/services/ocr_service.py +++ b/backend/services/ocr_service.py @@ -65,6 +65,12 @@ async def extract_text(self, image_base64: str) -> str: return "" image_base64 = image_base64.strip() + if len(image_base64) > MAX_BASE64_LENGTH: + logger.warning( + "[OCRService] Rejected: base64 length %d exceeds limit %d", + len(image_base64), MAX_BASE64_LENGTH, + ) + return "" # ── 1. Strip data-URI prefix ───────────────────────────────────────── content_type = None @@ -77,17 +83,16 @@ async def extract_text(self, image_base64: str) -> str: logger.warning("[OCRService] Rejected: unsupported content type %s", content_type) return "" - # ── 2. Re-add padding ───────────────────────────────────────────────── + # ── 2. Normalise MIME-wrapped base64 ────────────────────────────────── + image_base64 = "".join(image_base64.split()) + if not image_base64: + return "" + + # ── 3. Re-add padding ───────────────────────────────────────────────── missing_padding = len(image_base64) % 4 if missing_padding: image_base64 += "=" * (4 - missing_padding) - # ── 3. Validate base64 characters ───────────────────────────────────── - _b64_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=") - if not all(c in _b64_chars for c in image_base64): - logger.warning("[OCRService] Rejected: invalid base64 characters detected.") - return "" - # ── 4. Base64 length guard ──────────────────────────────────────────── if len(image_base64) > MAX_BASE64_LENGTH: logger.warning( @@ -98,7 +103,7 @@ async def extract_text(self, image_base64: str) -> str: try: # ── 5. Decode ───────────────────────────────────────────────────── - image_bytes = base64.b64decode(image_base64) + image_bytes = base64.b64decode(image_base64, validate=True) # ── 6. Decoded-bytes guard ──────────────────────────────────────── if len(image_bytes) > MAX_DECODED_BYTES: diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 2baed5dc2..42d9f1905 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -528,6 +528,8 @@ def mock_ai_services(request): "test_notification_routing.py", "test_notification_routing_push.py", "test_notification_routing_admin_alert.py", + "test_ocr_service.py", + "test_ocr_dos_protection.py", }: yield return diff --git a/backend/tests/test_ocr_service.py b/backend/tests/test_ocr_service.py index 5fd13f397..5444c4b34 100644 --- a/backend/tests/test_ocr_service.py +++ b/backend/tests/test_ocr_service.py @@ -83,6 +83,26 @@ async def test_strips_data_uri_prefix(self): result = await svc.extract_text(f"data:image/png;base64,{b64}") assert result == "hello" + @pytest.mark.asyncio + async def test_accepts_mime_wrapped_base64(self): + svc = OCRService() + tiny = _make_tiny_png_bytes() + wrapped = base64.encodebytes(tiny).decode() + + with patch.object(svc, "_run_ocr", return_value=["wrapped"]): + result = await svc.extract_text(wrapped) + assert result == "wrapped" + + @pytest.mark.asyncio + async def test_accepts_mime_wrapped_data_uri(self): + svc = OCRService() + tiny = _make_tiny_png_bytes() + wrapped = base64.encodebytes(tiny).decode() + + with patch.object(svc, "_run_ocr", return_value=["uri wrapped"]): + result = await svc.extract_text(f"data:image/png;base64,{wrapped}") + assert result == "uri wrapped" + @pytest.mark.asyncio async def test_rejects_unsupported_content_type(self): svc = OCRService()