diff --git a/.github/workflows/build-backend.yml b/.github/workflows/build-backend.yml index 39a5fff..3c3ca51 100644 --- a/.github/workflows/build-backend.yml +++ b/.github/workflows/build-backend.yml @@ -25,7 +25,7 @@ jobs: run: | # Start container in background docker run -d --name spendoo-test \ - -e GROQ_API_KEY=${{ secrets.GROQ_API_KEY }} -e MISTRAL_API_KEY=${{ secrets.MISTRAL_API_KEY }} -e SPENDOO_DEPLOY=${{ secrets.SPENDOO_DEPLOY }} -e SPENDOO_ALLOWED_IP=${{ secrets.SPENDOO_ALLOWED_IP }} -p 8000:8000 spendoo-ai-backend + -e GROQ_API_KEY=${{ secrets.GROQ_API_KEY }} -e MISTRAL_API_KEY=${{ secrets.MISTRAL_API_KEY }} -e SPENDOO_DEPLOY=${{ secrets.SPENDOO_DEPLOY }} -e SPENDOO_ALLOWED_IP=${{ secrets.SPENDOO_ALLOWED_IP }} -e GEMINI_API_KEY=${{ secrets.GEMINI_API_KEY }} -p 8000:8000 spendoo-ai-backend # Wait for FastAPI to start sleep 10 # Check health endpoint diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 203d828..3d0941c 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -30,5 +30,5 @@ jobs: docker stop spendoo-ai-container || true docker rm spendoo-ai-container || true docker run -d --restart=always --name spendoo-ai-container -p 8000:8000 \ - -e GROQ_API_KEY="${{ secrets.GROQ_API_KEY }}" -e MISTRAL_API_KEY="${{ secrets.MISTRAL_API_KEY }}" -e SPENDOO_DEPLOY="${{ secrets.SPENDOO_DEPLOY }}" -e SPENDOO_ALLOWED_IP="${{ secrets.SPENDOO_ALLOWED_IP }}" josephsameh/spendoo-ai-backend:latest + -e GROQ_API_KEY="${{ secrets.GROQ_API_KEY }}" -e MISTRAL_API_KEY="${{ secrets.MISTRAL_API_KEY }}" -e SPENDOO_DEPLOY="${{ secrets.SPENDOO_DEPLOY }}" -e SPENDOO_ALLOWED_IP="${{ secrets.SPENDOO_ALLOWED_IP }}" -e GEMINI_API_KEY="${{ secrets.GEMINI_API_KEY }}" josephsameh/spendoo-ai-backend:latest docker image prune -f \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f246f9f..7af31ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,6 @@ uvicorn>=0.15,<1 mistralai==1.12.4 
python-dotenv>=0.19,<1 groq>=0.1,<1 -python-multipart>=0.0.5,<1 \ No newline at end of file +python-multipart>=0.0.5,<1 +google.genai>=1.69.0,<2 +flask>=3.0,<4 diff --git a/spendoo/core/config.py b/spendoo/core/config.py index 75feca5..9664d50 100644 --- a/spendoo/core/config.py +++ b/spendoo/core/config.py @@ -5,7 +5,21 @@ class Settings: GROQ_API_KEY: str = os.getenv("GROQ_API_KEY") + MISTRAL_API_KEY: str = os.getenv("MISTRAL_API_KEY") + + GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY") + GEMINI_OCR_MODELS = [ + "gemini-3.1-flash-lite-preview", + "gemini-2.5-flash", + ] + OCR_PROMPT = """ + Extract all line items from this receipt. Return ONLY a valid JSON array, no markdown, no explanation. + Format: [{"name": "item name (keep Arabic as-is)", "price": 12.50 (total price for that line, not per unit)}] + Return total amount paid if given and DO NOT include it as a line item. + Format: {"items": [...], "total": 12.50} + """ + VOICE_ALLOWED_EXTENSIONS: set[str] = {".wav", ".mp3", ".ogg", ".m4a", ".flac", ".webm"} VOICE_MAX_SIZE: int = 25 * 1024 * 1024 diff --git a/spendoo/ocr/pipeline.py b/spendoo/ocr/pipeline.py index 022518e..4d6ca2d 100644 --- a/spendoo/ocr/pipeline.py +++ b/spendoo/ocr/pipeline.py @@ -1,5 +1,4 @@ from spendoo.core.services import ServiceContainer -import base64 class ReceiptPipeline: @@ -10,11 +9,10 @@ def __init__(self): def process_receipt(self, image_bytes): - base64_image = base64.b64encode(image_bytes).decode("utf-8") # Step 1: OCR - receipt_text = self.ocr.extract_text(base64_image) + receipt_text = self.ocr.extract_text(image_bytes) # Step 2: LLM extraction - structured_data = self.extractor.extract(receipt_text) + structured_data = self.extractor.extract(str(receipt_text)) return structured_data \ No newline at end of file diff --git a/spendoo/ocr/service.py b/spendoo/ocr/service.py index 72458b3..6c971e2 100644 --- a/spendoo/ocr/service.py +++ b/spendoo/ocr/service.py @@ -1,26 +1,77 @@ -import base64 +from flask import json from 
class OCRService:
    """Extract structured receipt data from an image using hosted models.

    Gemini models from ``settings.GEMINI_OCR_MODELS`` are tried in order; if
    none yields parseable JSON, a Mistral OCR + chat-completion pipeline is
    used as the final fallback.

    Relies on the file-level imports ``Mistral``, ``settings``, ``base64``,
    ``genai`` and ``types``.
    """

    def __init__(self):
        # NOTE: attribute names are kept in camelCase because other modules
        # may already read them; renaming would be a breaking change.
        self.mistralClient = Mistral(api_key=settings.MISTRAL_API_KEY)
        self.geminiClient = genai.Client(api_key=settings.GEMINI_API_KEY)
        self.geminiModels = settings.GEMINI_OCR_MODELS
        self.prompt = settings.OCR_PROMPT

    def parse_json(self, response: str):
        """Parse a model reply into Python data.

        Tolerates a surrounding Markdown code fence (with or without the
        ``json`` language tag).

        Args:
            response: Raw text returned by the model. ``None`` or an empty
                string is accepted and treated as unparseable.

        Returns:
            The decoded JSON value, or ``None`` when the reply is missing or
            is not valid JSON.
        """
        # Use the standard-library json module explicitly: the json object
        # exported by Flask does not provide JSONDecodeError.
        import json
        import logging

        if not response:
            return None

        cleaned_response = (
            response.strip()
            .removeprefix("```json")
            .removeprefix("```")
            .removesuffix("```")
            .strip()
        )
        try:
            return json.loads(cleaned_response)
        except json.JSONDecodeError as exc:
            logging.getLogger(__name__).warning("JSON decoding error: %s", exc)
            return None

    def try_gemini_ocr(self, model: str, image_bytes: bytes, mime_type: str = "image/jpeg"):
        """Send the image and the OCR prompt to one Gemini model.

        Args:
            model: Gemini model identifier.
            image_bytes: Raw image content.
            mime_type: MIME type declared for the image.

        Returns:
            The parsed JSON reply, or ``None`` if the reply is not valid JSON.
        """
        response = self.geminiClient.models.generate_content(
            model=model,
            contents=[
                types.Part.from_bytes(data=image_bytes, mime_type=mime_type),
                self.prompt,
            ],
        )
        return self.parse_json(response.text)

    def try_mistral(self, image_bytes: bytes, mime_type: str = "image/jpeg"):
        """Run Mistral OCR on the image, then parse items with a chat model.

        Args:
            image_bytes: Raw image content.
            mime_type: MIME type used in the data URL sent to the OCR model.

        Returns:
            The parsed JSON reply, or ``None`` if the reply is not valid JSON.
        """
        b64 = base64.b64encode(image_bytes).decode()

        # Step 1: OCR the image into per-page markdown.
        ocr_response = self.mistralClient.ocr.process(
            model="mistral-ocr-latest",
            document={
                "type": "image_url",
                "image_url": f"data:{mime_type};base64,{b64}",
            },
        )
        ocr_text = "\n\n".join(
            f"### Page {page_number}\n{page.markdown}"
            for page_number, page in enumerate(ocr_response.pages, start=1)
        )

        # Step 2: Parse items from OCR text using a Mistral chat model
        parse_response = self.mistralClient.chat.complete(
            model="mistral-small-latest",
            messages=[{
                "role": "user",
                "content": f"Receipt text:\n{ocr_text}\n\n{self.prompt}",
            }],
        )
        return self.parse_json(parse_response.choices[0].message.content)

    def extract_text(self, image_bytes, mime_type: str = "image/jpeg"):
        """Extract receipt data, falling back across models until one succeeds.

        A model counts as failed when it raises or when its reply cannot be
        parsed as JSON (``None`` result).

        Args:
            image_bytes: Raw image content.
            mime_type: MIME type of the image; defaults to JPEG.

        Returns:
            The parsed JSON value produced by the first successful model.

        Raises:
            RuntimeError: If every Gemini model and the Mistral fallback fail.
        """
        import logging

        logger = logging.getLogger(__name__)

        # Try Gemini models first
        for model in self.geminiModels:
            try:
                items = self.try_gemini_ocr(model, image_bytes, mime_type)
            except Exception:
                # Best-effort fallback chain: record the failure and move on.
                logger.warning(
                    "Gemini model %s failed; trying next model", model, exc_info=True
                )
                continue
            if items is not None:
                return items
            logger.warning("Gemini model %s returned unparseable output", model)

        # Final fallback: Mistral OCR pipeline
        try:
            items = self.try_mistral(image_bytes, mime_type)
        except Exception as exc:
            raise RuntimeError("Receipt extraction failed across all models.") from exc
        if items is None:
            raise RuntimeError("Receipt extraction failed across all models.")
        return items