Merge pull request #279 from enoch3712/245-add-global-models-to-test
Refactor of the tests. document_loader accepts multiple source types. Multi-image fix.
enoch3712 authored Feb 21, 2025
2 parents fd24ba3 + b436c8b commit 5d06ed3
Showing 7 changed files with 107 additions and 100 deletions.
4 changes: 4 additions & 0 deletions extract_thinker/document_loader/document_loader_data.py
@@ -51,6 +51,8 @@ def can_handle(self, source: Any) -> bool:
return True
if isinstance(source, list) and all(isinstance(item, dict) for item in source):
return True
if isinstance(source, dict):
return True
return False

@cachedmethod(cache=attrgetter('cache'),
@@ -80,6 +82,8 @@ def load(self, source: Union[str, IO, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
return self._load_from_string(source)
elif hasattr(source, "read"):
return self._load_from_stream(source)
elif isinstance(source, dict):
return source

except Exception as e:
raise ValueError(f"Error processing content: {str(e)}")
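
A quick illustration of the new branch (a hedged sketch, not code from the PR; it assumes the loader is used standalone):

# Sketch: DocumentLoaderData now accepts a bare dict in addition to
# strings, streams, and lists of dicts.
from extract_thinker.document_loader.document_loader_data import DocumentLoaderData

loader = DocumentLoaderData()
assert loader.can_handle({"content": "Invoice 0000001"})                  # new branch
assert loader.can_handle([{"content": "page 1"}, {"content": "page 2"}])  # existing
assert not loader.can_handle(42)                                          # still rejected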
69 changes: 54 additions & 15 deletions extract_thinker/extractor.py
@@ -4,6 +4,7 @@
from instructor.batch import BatchJob
import uuid
from pydantic import BaseModel
from extract_thinker.document_loader.document_loader_data import DocumentLoaderData
from extract_thinker.llm_engine import LLMEngine
from extract_thinker.concatenation_handler import ConcatenationHandler
from extract_thinker.document_loader.document_loader import DocumentLoader
@@ -53,6 +54,7 @@ def __init__(
self.is_classify_image: bool = False
self._skip_loading: bool = False
self.chunk_height: int = 1500
self.allow_vision: bool = False

def add_interceptor(
self, interceptor: Union[LoaderInterceptor, LlmInterceptor]
@@ -85,7 +87,7 @@ def get_document_loader_for_file(self, source: Union[str, IO]) -> DocumentLoader

raise ValueError("No suitable document loader found for the input.")

def get_document_loader(self, source: Union[str, IO]) -> Optional[DocumentLoader]:
def get_document_loader(self, source: Union[str, IO, List[Union[str, IO]]]) -> Optional[DocumentLoader]:
"""
Retrieve the appropriate document loader for the given source.
@@ -110,6 +112,14 @@ def get_document_loader
for loader in self.document_loaders_by_file_type.values():
if loader.can_handle(source):
return loader

# If the source is a list or dict (usually the output of a split), return DocumentLoaderData
if isinstance(source, List) or isinstance(source, dict):
return DocumentLoaderData()

# Last check: if vision is allowed, fall back to DocumentLoaderLLMImage
if self.allow_vision:
return DocumentLoaderLLMImage()

return None
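
The new fallback order can be sketched as follows (hypothetical usage; no loaders are registered here, and allow_vision is normally set inside extract() rather than by hand):

# Sketch of the new fallback behavior in get_document_loader.
from extract_thinker import Extractor

extractor = Extractor()

# A list or dict (typically the output of a split) now resolves to
# DocumentLoaderData instead of falling through to None.
pages = [{"content": "page 1"}, {"content": "page 2"}]
loader = extractor.get_document_loader(pages)          # -> DocumentLoaderData

# Anything still unmatched falls back to the LLM image loader when
# vision is allowed.
extractor.allow_vision = True
fallback = extractor.get_document_loader("photo.png")  # -> DocumentLoaderLLMImage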

@@ -148,6 +158,36 @@ def set_skip_loading(self, skip: bool = True) -> None:
"""Internal method to control content loading behavior"""
self._skip_loading = skip

def remove_images_from_content(self, content: Union[Dict[str, Any], List[Dict[str, Any]], str]) -> Union[Dict[str, Any], List[Dict[str, Any]], str]:
"""
Remove image-related keys from the content while preserving the original structure.
Args:
content: Input content that can be a dictionary, list of dictionaries, or string
Returns:
Content with image-related keys removed, maintaining the original type
"""
if isinstance(content, dict):
# Build a filtered (shallow) copy so the original dict is not modified
content_copy = {
k: v for k, v in content.items()
if k not in ('images', 'image')
}
return content_copy

elif isinstance(content, list):
# Handle list of dictionaries
return [
self.remove_images_from_content(item)
if isinstance(item, (dict, list))
else item
for item in content
]

# Return strings or other types unchanged
return content
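
Its effect on the page-based format, sketched with hypothetical data (assumes an Extractor instance named extractor):

# Sketch: image keys are stripped at every level; everything else is kept.
pages = [
    {"content": "Invoice text", "image": b"<png bytes>"},
    {"content": "Terms page", "images": [b"<p1>", b"<p2>"]},
    "a plain string page",
]
cleaned = extractor.remove_images_from_content(pages)
# -> [{"content": "Invoice text"}, {"content": "Terms page"}, "a plain string page"]
# Strings and other non-dict items pass through unchanged.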

def extract(
self,
source: Union[str, IO, List[Union[str, IO]]],
@@ -176,12 +216,16 @@ def extract(
self._validate_dependencies(response_model, vision)
self.extra_content = content
self.completion_strategy = completion_strategy
self.allow_vision = vision

if vision:
try:
self._handle_vision_mode(source)
except ValueError as e:
raise InvalidVisionDocumentLoaderError(str(e))
else:
if isinstance(source, List):
source = self.remove_images_from_content(source)

if completion_strategy is not CompletionStrategy.FORBIDDEN:
return self.extract_with_strategy(source, response_model, vision, completion_strategy)
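
In practice the two paths look like this (hedged sketch; InvoiceContract is one of the repo's test contracts, and pages stands for the output of a split):

# vision=False (default): any leftover image payloads from a split are
# stripped before the LLM call.
result = extractor.extract(pages, InvoiceContract)

# vision=True: the source goes through the vision handler, and allow_vision
# keeps the image-capable fallback loader reachable.
result = extractor.extract("invoice.png", InvoiceContract, vision=True)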
@@ -1149,23 +1193,18 @@ def _add_images_to_message_content(
content: Union[Dict[str, Any], List[Any]],
message_content: List[Dict[str, Any]],
) -> None:
"""
Add images to the message content.
Handles both legacy format and new page-based format from document loaders.
Args:
content: The content containing images.
message_content: The message content to append images to.
"""
if isinstance(content, list):
# Handle new page-based format
for page in content:
if isinstance(page, dict) and 'image' in page:
self._append_images(page['image'], message_content)
if isinstance(page, dict):
if 'image' in page:
self._append_images(page['image'], message_content)
if 'images' in page:
self._append_images(page['images'], message_content)
elif isinstance(content, dict):
# Handle legacy format
image_data = content.get('image') or content.get('images')
self._append_images(image_data[0], message_content)
if 'image' in content:
self._append_images(content['image'], message_content)
if 'images' in content:
self._append_images(content['images'], message_content)

def _append_images(
self,
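
The two content shapes _add_images_to_message_content now understands, sketched with hypothetical data. This is the multi-image fix: the legacy branch previously forwarded only image_data[0], dropping all but the first image.

# New page-based format: each page may carry 'image' and/or 'images'.
pages = [
    {"content": "page 1", "image": "<base64>"},
    {"content": "page 2", "images": ["<base64-a>", "<base64-b>"]},
]

# Legacy flat format: a single dict with either key; every image is now
# forwarded instead of only the first one.
legacy = {"content": "doc", "images": ["<base64-a>", "<base64-b>"]}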
5 changes: 3 additions & 2 deletions extract_thinker/global_models.py
@@ -1,8 +1,9 @@
def get_lite_model():
"""Return the lite model for cost efficiency."""
#return "vertex_ai/gemini-2.0-flash"
return "gpt-4o-mini"


def get_big_model():
"""Return the big model for high performance."""
return "gpt-4o"
#return "vertex_ai/gemini-2.0-flash"
return "gpt-4o"
3 changes: 0 additions & 3 deletions extract_thinker/process.py
@@ -4,14 +4,11 @@
from extract_thinker.models.classification_response import ClassificationResponse
from extract_thinker.models.classification_strategy import ClassificationStrategy
from extract_thinker.models.completion_strategy import CompletionStrategy
from extract_thinker.models.doc_groups2 import DocGroups2
from extract_thinker.models.splitting_strategy import SplittingStrategy
from extract_thinker.extractor import Extractor
from extract_thinker.models.classification import Classification
from extract_thinker.document_loader.document_loader import DocumentLoader
from extract_thinker.models.classification_tree import ClassificationTree
from extract_thinker.models.classification_node import ClassificationNode
from extract_thinker.models.doc_group import DocGroup
from extract_thinker.splitter import Splitter
from extract_thinker.models.doc_groups import (
DocGroups,
93 changes: 14 additions & 79 deletions tests/test_extractor.py
@@ -18,15 +18,16 @@
from tests.models.handbook_contract import HandbookContract
from extract_thinker.global_models import get_lite_model, get_big_model
from pydantic import BaseModel, Field
from extract_thinker.exceptions import ExtractThinkerError


load_dotenv()
cwd = os.getcwd()

def test_extract_with_pypdf_and_gpt4o_mini():
def test_extract_with_pypdf_and_gpt4o_mini_vision():

# Arrange
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")
test_file_path = os.path.join(cwd, "tests", "test_images", "invoice.png")

extractor = Extractor()
extractor.load_document_loader(
@@ -35,51 +36,13 @@ def test_extract_with_pypdf_and_gpt4o_mini():
extractor.load_llm(get_lite_model())

# Act
result = extractor.extract(test_file_path, InvoiceContract)
result = extractor.extract(test_file_path, InvoiceContract, vision=True)

# Assert
assert result is not None
assert result.invoice_number == "0000001"
assert result.invoice_date == "2014-05-07"

def test_extract_with_azure_di_and_gpt4o_mini():
subscription_key = os.getenv("AZURE_SUBSCRIPTION_KEY")
endpoint = os.getenv("AZURE_ENDPOINT")
test_file_path = os.path.join(cwd, "tests", "test_images", "invoice.png")

extractor = Extractor()
extractor.load_document_loader(
DocumentLoaderAzureForm(subscription_key, endpoint)
)
extractor.load_llm(get_lite_model())
# Act
result = extractor.extract(test_file_path, InvoiceContract)

# Assert
assert result is not None
assert result.lines[0].description == "Website Redesign"
assert result.lines[0].quantity == 1
assert result.lines[0].unit_price == 2500
assert result.lines[0].amount == 2500

def test_extract_with_pypdf_and_gpt4o_mini():
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")

extractor = Extractor()
document_loader = DocumentLoaderPyPdf()
extractor.load_document_loader(document_loader)
extractor.load_llm("gpt-4o-mini")

# Act
result = extractor.extract(test_file_path, InvoiceContract, vision=True)

# Assert
assert result is not None
assert result.lines[0].description == "Consultation services"
assert result.lines[0].quantity == 3
assert result.lines[0].unit_price == 375
assert result.lines[0].amount == 1125

def test_vision_content_pdf():
# Arrange
extractor = Extractor()
@@ -156,10 +119,10 @@ def test_extract_with_invalid_file_path():
invalid_file_path = os.path.join(cwd, "tests", "nonexistent", "fake_file.png")

# Act & Assert
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ExtractThinkerError) as exc_info:
extractor.extract(invalid_file_path, InvoiceContract, vision=True)

assert "Failed to extract from source" in str(exc_info.value.args[0])
assert "Failed to extract from source: Cannot handle source" in str(exc_info.value)

def test_forbidden_strategy_with_token_limit():
test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "eu_tax_chart.png")
@@ -358,34 +321,6 @@ def test_llm_timeout():
result = extractor.extract(test_file_path, InvoiceContract)
assert result is not None

def test_dynamic_json_parsing():
"""Test dynamic JSON parsing with local Ollama model."""
# Initialize components
llm = LLM(model="ollama/deepseek-r1:1.5b")
llm.set_dynamic(True) # Enable dynamic JSON parsing

document_loader = DocumentLoaderPyPdf()
extractor = Extractor(document_loader=document_loader, llm=llm)

# Test content that should produce JSON response
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")

# Extract with dynamic parsing
try:
result = extractor.extract(test_file_path, InvoiceContract)

# Verify the result is an InvoiceContract instance
assert isinstance(result, InvoiceContract)

# Verify invoice fields
assert result.invoice_number is not None
assert result.invoice_date is not None
assert result.total_amount is not None
assert isinstance(result.lines, list)

except Exception as e:
pytest.fail(f"Dynamic JSON parsing test failed: {str(e)}")

def test_extract_with_default_backend():
"""Test extraction using default LiteLLM backend"""
# Arrange
@@ -407,8 +342,6 @@ def test_extract_with_default_backend():
def test_extract_with_pydanticai_backend():
"""Test extraction using PydanticAI backend if available"""
try:
import pydantic_ai

# Arrange
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")

@@ -439,13 +372,12 @@ def test_extract_from_url_docling_and_gpt4o_mini():
extractor = Extractor()
extractor.load_document_loader(DocumentLoaderDocling())
extractor.load_llm(get_lite_model())

# Act: Extract the document using the specified URL and the HandbookContract
result = extractor.extract(url, HandbookContract)
result: HandbookContract = extractor.extract(url, HandbookContract)

# Assert: Verify that the extracted title matches the expected value.
expected_title = "BCOBS 2A.1 Restriction on marketing or providing an optional product for which a fee is payable"
assert result.title == expected_title
# Check handbook data
assert "FCA Handbook" in result.title, f"Expected title to contain 'FCA Handbook', but got: {result.title}"

def test_extract_from_multiple_sources():
"""
@@ -480,4 +412,7 @@ class CombinedData(BaseModel):
assert result.total_amount == 1125

# Check handbook data
assert "FCA Handbook" in result.handbook_title, f"Expected title to contain 'FCA Handbook', but got: {result.handbook_title}"
assert "FCA Handbook" in result.handbook_title, f"Expected title to contain 'FCA Handbook', but got: {result.handbook_title}"

if __name__ == "__main__":
test_extract_with_invalid_file_path()
32 changes: 31 additions & 1 deletion tests/test_ollama.py
@@ -1,12 +1,14 @@
import os
from typing import Optional
from dotenv import load_dotenv
import pytest
from extract_thinker import DocumentLoaderPyPdf
from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling, DoclingConfig
from extract_thinker import Extractor
from extract_thinker import Contract
from extract_thinker import Classification
from extract_thinker import DocumentLoaderMarkItDown
from extract_thinker.llm import LLM
from extract_thinker.models.completion_strategy import CompletionStrategy
from extract_thinker import SplittingStrategy
from extract_thinker import Process
@@ -160,4 +162,32 @@ def test_extract_with_ollama_full_pipeline():

# Check each extracted item
for item in result:
assert isinstance(item, (VehicleRegistration, DriverLicenseContract))

def test_dynamic_json_parsing():
"""Test dynamic JSON parsing with local Ollama model."""
# Initialize components
llm = LLM(model="ollama/deepseek-r1:1.5b")
llm.set_dynamic(True) # Enable dynamic JSON parsing

document_loader = DocumentLoaderPyPdf()
extractor = Extractor(document_loader=document_loader, llm=llm)

# Test content that should produce JSON response
test_file_path = os.path.join(cwd, "tests", "files", "invoice.pdf")

# Extract with dynamic parsing
try:
result = extractor.extract(test_file_path, InvoiceContract)

# Verify the result is an InvoiceContract instance
assert isinstance(result, InvoiceContract)

# Verify invoice fields
assert result.invoice_number is not None
assert result.invoice_date is not None
assert result.total_amount is not None
assert isinstance(result.lines, list)

except Exception as e:
pytest.fail(f"Dynamic JSON parsing test failed: {str(e)}")
1 change: 1 addition & 0 deletions tests/test_process.py
@@ -65,6 +65,7 @@ def setup_process_and_classifications():

def test_eager_splitting_strategy():
"""Test eager splitting strategy with a multi-page document"""

# Arrange
process, classifications = setup_process_and_classifications()
process.load_splitter(ImageSplitter(get_big_model()))
