Skip to content

Commit

Permalink
Merge pull request #197 from enoch3712/195-processextract-must-have-a…
Browse files Browse the repository at this point in the history
…-completion-strategy

Add strategy to Process.Extract
  • Loading branch information
enoch3712 authored Jan 17, 2025
2 parents c7f556f + 630bfb1 commit 6dd3cd9
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 14 deletions.
4 changes: 4 additions & 0 deletions extract_thinker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .document_loader.document_loader_pdfplumber import DocumentLoaderPdfPlumber
from .document_loader.document_loader_beautiful_soup import DocumentLoaderBeautifulSoup
from .document_loader.document_loader_markitdown import DocumentLoaderMarkItDown
from .document_loader.document_loader_docling import DocumentLoaderDocling
from .models.classification import Classification
from .models.classification_response import ClassificationResponse
from .process import Process
Expand All @@ -18,6 +19,7 @@
from .text_splitter import TextSplitter
from .models.contract import Contract
from .models.splitting_strategy import SplittingStrategy
from .models.completion_strategy import CompletionStrategy
from .batch_job import BatchJob
from .document_loader.document_loader_txt import DocumentLoaderTxt
from .document_loader.document_loader_doc2txt import DocumentLoaderDoc2txt
Expand Down Expand Up @@ -47,6 +49,8 @@
'DocumentLoaderDocumentAI',
'DocumentLoaderMarkItDown',
'Classification',
'CompletionStrategy',
'DocumentLoaderDocling',
'ClassificationResponse',
'Process',
'ClassificationStrategy',
Expand Down
19 changes: 12 additions & 7 deletions extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ async def extract_async(
source: Union[str, IO, list],
response_model: type[BaseModel],
vision: bool = False,
completion_strategy: Optional[CompletionStrategy] = CompletionStrategy.FORBIDDEN
) -> Any:
return await asyncio.to_thread(self.extract, source, response_model, vision)
return await asyncio.to_thread(self.extract, source, response_model, vision, "", completion_strategy)

def extract_with_strategy(
self,
Expand All @@ -265,13 +266,17 @@ def extract_with_strategy(
Returns:
Parsed response matching response_model
"""
# Get appropriate document loader
document_loader = self.get_document_loader(source)
if document_loader is None:
raise ValueError("No suitable document loader found for the input.")
# If source is already a list, use it directly
if isinstance(source, list):
content = source
else:
# Get appropriate document loader
document_loader = self.get_document_loader(source)
if document_loader is None:
raise ValueError("No suitable document loader found for the input.")

# Load content using list method
content = document_loader.load(source)
# Load content using list method
content = document_loader.load(source)

# Handle based on strategy
if completion_strategy == CompletionStrategy.PAGINATE:
Expand Down
14 changes: 11 additions & 3 deletions extract_thinker/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from extract_thinker.image_splitter import ImageSplitter
from extract_thinker.models.classification_response import ClassificationResponse
from extract_thinker.models.classification_strategy import ClassificationStrategy
from extract_thinker.models.completion_strategy import CompletionStrategy
from extract_thinker.models.doc_groups2 import DocGroups2
from extract_thinker.models.splitting_strategy import SplittingStrategy
from extract_thinker.extractor import Extractor
Expand Down Expand Up @@ -232,15 +233,17 @@ def split(self, classifications: List[Classification], strategy: SplittingStrate

return self

def extract(self, vision: bool = False) -> List[Any]:
def extract(self,
vision: bool = False,
completion_strategy: Optional[CompletionStrategy] = CompletionStrategy.FORBIDDEN) -> List[Any]:
"""Extract information from the document groups."""
if self.doc_groups is None:
raise ValueError("Document groups have not been initialized")

async def _extract(doc_group):
# Find matching classification and extractor
classificationStr = doc_group.classification
extractor = None
extractor: Optional[Extractor] = None
contract = None

for classification in self.split_classifications:
Expand Down Expand Up @@ -271,7 +274,12 @@ async def _extract(doc_group):
# Set flag to skip loading since content is already processed
extractor.set_skip_loading(True)
try:
result = await extractor.extract_async(group_pages, contract, vision=vision)
result = await extractor.extract_async(
group_pages,
contract,
vision,
completion_strategy
)
finally:
# Reset flag after extraction
extractor.set_skip_loading(False)
Expand Down
143 changes: 139 additions & 4 deletions tests/test_ollama.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,62 @@
import os
from typing import Optional
from dotenv import load_dotenv
from extract_thinker.document_loader.document_loader_pypdf import DocumentLoaderPyPdf
from extract_thinker.extractor import Extractor
from tests.models.invoice import InvoiceContract
from extract_thinker import DocumentLoaderPyPdf
from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling, DoclingConfig
from extract_thinker import Extractor
from extract_thinker import Contract
from extract_thinker import Classification
from extract_thinker import DocumentLoaderMarkItDown
from extract_thinker.models.completion_strategy import CompletionStrategy
from extract_thinker import SplittingStrategy
from extract_thinker import Process
from extract_thinker import TextSplitter
from extract_thinker import ImageSplitter
from pydantic import Field

from docling.datamodel.pipeline_options import (
PdfPipelineOptions,
TesseractCliOcrOptions,
TableStructureOptions,
)
from docling.datamodel.base_models import InputFormat
from docling.document_converter import PdfFormatOption

load_dotenv()
cwd = os.getcwd()

# Define the contracts as shown in the article
class InvoiceContract(Contract):
    # Invoice fields requested from the model; every field is required.
    # A '#' comment is used instead of a docstring so the schema generated
    # from this model is left exactly as it was.
    invoice_number: str = Field(
        description="Unique invoice identifier",
    )
    invoice_date: str = Field(
        description="Date of the invoice",
    )
    total_amount: float = Field(
        description="Overall total amount",
    )

class VehicleRegistration(Contract):
    # Vehicle-registration fields. Each one defaults to None, so a document
    # that lacks a value still validates.
    # A '#' comment is used instead of a docstring so the schema generated
    # from this model is left exactly as it was.
    name_primary: Optional[str] = Field(
        None,
        description="Primary registrant's name (Last, First, Middle)",
    )
    name_secondary: Optional[str] = Field(
        None,
        description="Co-registrant's name if applicable",
    )
    address: Optional[str] = Field(
        None,
        description="Primary registrant's mailing address including street, city, state and zip code",
    )
    vehicle_type: Optional[str] = Field(
        None,
        description="Type of vehicle (e.g., 2-Door, 4-Door, Pick-up, Van, etc.)",
    )
    vehicle_color: Optional[str] = Field(
        None,
        description="Primary color of the vehicle",
    )

class DriverLicenseContract(Contract):
    # Driver-license fields.
    # Each field carries an explicit default of None: under Pydantic v2 an
    # Optional[...] annotation alone only permits None as a value, the field
    # is still *required* unless a default is given. This matches how
    # VehicleRegistration declares its optional fields.
    # A '#' comment is used instead of a docstring so the model's generated
    # schema description is not altered.
    name: Optional[str] = Field(
        default=None,
        description="Full name on the license",
    )
    age: Optional[int] = Field(
        default=None,
        description="Age of the license holder",
    )
    license_number: Optional[str] = Field(
        default=None,
        description="License number",
    )

def test_extract_with_ollama():
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
Expand All @@ -17,7 +67,7 @@ def test_extract_with_ollama():
)

os.environ["API_BASE"] = "http://localhost:11434"
extractor.load_llm("ollama/phi3.5")
extractor.load_llm("ollama/phi4")

# Act
result = extractor.extract(test_file_path, InvoiceContract)
Expand All @@ -26,3 +76,88 @@ def test_extract_with_ollama():
assert result is not None
assert result.invoice_number == "00012"
assert result.invoice_date == "1/30/23"

def test_extract_with_ollama_full_pipeline():
    """Test the complete document processing pipeline as described in the article.

    Loads ``tests/files/bulk.pdf`` through a Docling loader configured for
    full-page Tesseract OCR, splits it with ``SplittingStrategy.LAZY`` into
    the two classifications defined below, and extracts each group with
    ``CompletionStrategy.PAGINATE`` using the ``ollama/phi4`` model served at
    ``http://localhost:11434``.

    The assertions only check the shape of the result: a list whose items
    are ``VehicleRegistration`` or ``DriverLicenseContract`` instances.
    """
    # Function-scope import: only this test needs it.
    import shutil

    # Setup test file path
    test_file_path = os.path.join(cwd, "tests", "files", "bulk.pdf")

    # Create classifications
    test_classifications = [
        Classification(
            name="Vehicle Registration",
            description="This is a vehicle registration document",
            contract=VehicleRegistration,
        ),
        Classification(
            name="Driver License",
            description="This is a driver license document",
            contract=DriverLicenseContract,
        ),
    ]

    # Setup OCR options.
    # Use whichever tesseract binary is on PATH so the test is not tied to a
    # macOS/Homebrew layout; fall back to the previously hard-coded location
    # so existing setups keep working.
    tesseract_cmd = shutil.which("tesseract") or "/opt/homebrew/bin/tesseract"
    ocr_options = TesseractCliOcrOptions(
        force_full_page_ocr=True,
        tesseract_cmd=tesseract_cmd,
    )

    # Setup pipeline options
    pipeline_options = PdfPipelineOptions(
        do_table_structure=True,
        do_ocr=True,
        ocr_options=ocr_options,
        table_structure_options=TableStructureOptions(
            do_cell_matching=True,
        ),
    )

    # Create format options
    format_options = {
        InputFormat.PDF: PdfFormatOption(
            pipeline_options=pipeline_options,
        ),
    }

    # Create docling config with OCR enabled
    docling_config = DoclingConfig(
        format_options=format_options,
        ocr_enabled=True,
        force_full_page_ocr=True,
    )

    # Setup extractor with OCR-enabled docling loader
    extractor = Extractor()
    extractor.load_document_loader(DocumentLoaderDocling(docling_config))

    # Configure Ollama
    os.environ["API_BASE"] = "http://localhost:11434"
    extractor.load_llm("ollama/phi4")

    # Attach the same extractor to every classification (done once; the
    # original repeated this assignment by index further down).
    for classification in test_classifications:
        classification.extractor = extractor

    # Setup process
    process = Process()
    process.load_document_loader(DocumentLoaderDocling(docling_config))
    process.load_splitter(ImageSplitter(model="claude-3-5-sonnet-20241022"))

    # Run the complete pipeline
    result = (
        process
        .load_file(test_file_path)
        .split(test_classifications, strategy=SplittingStrategy.LAZY)
        .extract(vision=False, completion_strategy=CompletionStrategy.PAGINATE)
    )

    # Assert
    assert result is not None
    assert isinstance(result, list)

    # Check each extracted item
    for item in result:
        assert isinstance(item, (VehicleRegistration, DriverLicenseContract))

0 comments on commit 6dd3cd9

Please sign in to comment.