Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions backend/services/ocr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ async def extract_text(self, image_base64: str) -> str:
return ""

image_base64 = image_base64.strip()
if len(image_base64) > MAX_BASE64_LENGTH:
logger.warning(
"[OCRService] Rejected: base64 length %d exceeds limit %d",
len(image_base64), MAX_BASE64_LENGTH,
)
return ""

# ── 1. Strip data-URI prefix ─────────────────────────────────────────
content_type = None
Expand All @@ -77,17 +83,16 @@ async def extract_text(self, image_base64: str) -> str:
logger.warning("[OCRService] Rejected: unsupported content type %s", content_type)
return ""

# ── 2. Re-add padding ─────────────────────────────────────────────────
# ── 2. Normalise MIME-wrapped base64 ──────────────────────────────────
image_base64 = "".join(image_base64.split())
if not image_base64:
return ""

# ── 3. Re-add padding ─────────────────────────────────────────────────
missing_padding = len(image_base64) % 4
if missing_padding:
image_base64 += "=" * (4 - missing_padding)

# ── 3. Validate base64 characters ─────────────────────────────────────
_b64_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=")
if not all(c in _b64_chars for c in image_base64):
logger.warning("[OCRService] Rejected: invalid base64 characters detected.")
return ""

# ── 4. Base64 length guard ────────────────────────────────────────────
if len(image_base64) > MAX_BASE64_LENGTH:
logger.warning(
Expand All @@ -98,7 +103,7 @@ async def extract_text(self, image_base64: str) -> str:

try:
# ── 5. Decode ─────────────────────────────────────────────────────
image_bytes = base64.b64decode(image_base64)
image_bytes = base64.b64decode(image_base64, validate=True)

# ── 6. Decoded-bytes guard ────────────────────────────────────────
if len(image_bytes) > MAX_DECODED_BYTES:
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,8 @@ def mock_ai_services(request):
"test_notification_routing.py",
"test_notification_routing_push.py",
"test_notification_routing_admin_alert.py",
"test_ocr_service.py",
"test_ocr_dos_protection.py",
}:
yield
return
Expand Down
20 changes: 20 additions & 0 deletions backend/tests/test_ocr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ async def test_strips_data_uri_prefix(self):
result = await svc.extract_text(f"data:image/png;base64,{b64}")
assert result == "hello"

@pytest.mark.asyncio
async def test_accepts_mime_wrapped_base64(self):
    """A bare base64 payload produced by ``base64.encodebytes`` is accepted.

    ``encodebytes`` emits MIME-style output (newline-terminated lines), so
    the string handed to ``extract_text`` contains whitespace that is not
    part of the base64 alphabet. ``_run_ocr`` is patched out so only the
    input handling in ``extract_text`` is exercised.
    """
    service = OCRService()
    # MIME-wrapped text form of a small PNG fixture.
    wrapped = base64.encodebytes(_make_tiny_png_bytes()).decode()

    with patch.object(service, "_run_ocr", return_value=["wrapped"]):
        extracted = await service.extract_text(wrapped)

    assert extracted == "wrapped"

@pytest.mark.asyncio
async def test_accepts_mime_wrapped_data_uri(self):
    """A ``data:image/png;base64,`` URI with a MIME-wrapped payload is accepted.

    Same scenario as the bare-payload case, but the ``encodebytes`` output
    is placed behind a PNG data-URI prefix before being passed to
    ``extract_text``. ``_run_ocr`` is patched out so the returned text is
    fully controlled by the test.
    """
    service = OCRService()
    # MIME-wrapped text form of a small PNG fixture.
    wrapped = base64.encodebytes(_make_tiny_png_bytes()).decode()
    data_uri = f"data:image/png;base64,{wrapped}"

    with patch.object(service, "_run_ocr", return_value=["uri wrapped"]):
        extracted = await service.extract_text(data_uri)

    assert extracted == "uri wrapped"

@pytest.mark.asyncio
async def test_rejects_unsupported_content_type(self):
svc = OCRService()
Expand Down