Skip to content

Commit

Permalink
Merge pull request #35 from enoch3712/run_tests_fix_classification
Browse files Browse the repository at this point in the history
Classification fix and tests running properly. Refactor of DocumentLo…
  • Loading branch information
enoch3712 authored Sep 26, 2024
2 parents 1718465 + 6ff78c1 commit 8332ec5
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 115 deletions.
2 changes: 0 additions & 2 deletions extract_thinker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .document_loader.document_loader_spreadsheet import DocumentLoaderSpreadSheet
from .document_loader.document_loader_azure_document_intelligence import DocumentLoaderAzureForm
from .document_loader.document_loader_pypdf import DocumentLoaderPyPdf
from .document_loader.document_loader_text import DocumentLoaderText
from .models import classification, classification_response
from .process import Process, ClassificationStrategy
from .splitter import Splitter
Expand All @@ -24,7 +23,6 @@
'DocumentLoaderSpreadSheet',
'DocumentLoaderAzureForm',
'DocumentLoaderPyPdf',
'DocumentLoaderText',
'classification',
'classification_response',
'Process',
Expand Down
2 changes: 1 addition & 1 deletion extract_thinker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
# extractor.loadfile("C:\\Users\\Lopez\\Desktop\\MagniFinance\\examples\\outputTestOne.pdf").split(classifications)

extractor.load_document_loader(
DocumentLoaderTesseract("C:\\Program Files\\Tesseract-OCR\\tesseract.exe")
DocumentLoaderTesseract(os.getenv("TESSERACT_PATH"))
)
extractor.load_llm("claude-3-haiku-20240307")

Expand Down
29 changes: 28 additions & 1 deletion extract_thinker/document_loader/document_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,33 @@
import concurrent.futures
from typing import Any, Dict, List, Union
from cachetools import TTLCache

import os
from extract_thinker.utils import get_file_extension

class DocumentLoader(ABC):
def __init__(self, content: Any = None, cache_ttl: int = 300):
self.content = content
self.file_path = None
self.cache = TTLCache(maxsize=100, ttl=cache_ttl)

def can_handle(self, source: Union[str, BytesIO]) -> bool:
file_type = None
try:
if isinstance(source, str):
if not os.path.isfile(source):
return False
file_type = get_file_extension(source)
elif isinstance(source, BytesIO):
source.seek(0)
img = Image.open(source)
file_type = img.format.lower()
source.seek(0)
else:
return False
return file_type.lower() in [fmt.lower() for fmt in self.SUPPORTED_FORMATS]
except Exception:
return False

@abstractmethod
def load_content_from_file(self, file_path: str) -> Union[str, object]:
pass
Expand All @@ -22,6 +41,14 @@ def load_content_from_file(self, file_path: str) -> Union[str, object]:
def load_content_from_stream(self, stream: BytesIO) -> Union[str, object]:
pass

def load(self, source: Union[str, BytesIO]) -> Any:
if isinstance(source, str):
return self.load_content_from_file(source)
elif isinstance(source, BytesIO):
return self.load_content_from_stream(source)
else:
raise ValueError("Source must be a file path or a stream.")

def getContent(self) -> Any:
return self.content

Expand Down
4 changes: 3 additions & 1 deletion extract_thinker/document_loader/document_loader_pypdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Any, Dict, List, Union
from PyPDF2 import PdfReader
from extract_thinker.document_loader.document_loader_llm_image import DocumentLoaderLLMImage
from extract_thinker.utils import get_file_extension

SUPPORTED_FORMATS = ['pdf']

class DocumentLoaderPyPdf(DocumentLoaderLLMImage):
def __init__(self, content: Any = None, cache_ttl: int = 300):
Expand Down Expand Up @@ -38,4 +40,4 @@ def extract_data_from_pdf(self, reader: PdfReader) -> Union[str, Dict[str, Any]]
# if image_data:
# document_data["images"].append(image_data)

return document_data
return document_data
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
from cachetools import cachedmethod
from cachetools.keys import hashkey
from extract_thinker.utils import get_file_extension

SUPPORTED_FORMATS = ['xls', 'xlsx', 'xlsm', 'xlsb', 'odf', 'ods', 'odt', 'csv']

class DocumentLoaderSpreadSheet(CachedDocumentLoader):

def __init__(self, content=None, cache_ttl=300):
super().__init__(content, cache_ttl)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,4 @@ def load_content_from_file_list(self, input: List[Union[str, BytesIO]]) -> List[
image, content = output_queue.get()
contents.append({"image": Image.open(image), "content": content})

return contents
return contents
24 changes: 0 additions & 24 deletions extract_thinker/document_loader/document_loader_text.py

This file was deleted.

11 changes: 0 additions & 11 deletions extract_thinker/document_loader/text_extract_loader.py

This file was deleted.

46 changes: 17 additions & 29 deletions extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@

from extract_thinker.utils import get_file_extension, encode_image, json_to_formatted_string
import yaml
import litellm

SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "bmp", "tiff"]
SUPPORTED_EXCEL_FORMATS = ['.xls', '.xlsx', '.xlsm', '.xlsb', '.odf', '.ods', '.odt', '.csv']


class Extractor:
def __init__(
self, processor: Optional[DocumentLoader] = None, llm: Optional[LLM] = None
self, document_loader: Optional[DocumentLoader] = None, llm: Optional[LLM] = None
):
self.document_loader: Optional[DocumentLoader] = processor
self.document_loader: Optional[DocumentLoader] = document_loader
self.llm: Optional[LLM] = llm
self.file: Optional[str] = None
self.document_loaders_by_file_type: Dict[str, DocumentLoader] = {}
Expand All @@ -47,10 +42,14 @@ def add_interceptor(
"Interceptor must be an instance of LoaderInterceptor or LlmInterceptor"
)

def set_document_loader_for_file_type(
self, file_type: str, document_loader: DocumentLoader
):
self.document_loaders_by_file_type[file_type] = document_loader
def get_document_loader_for_file(self, source: Union[str, IO]) -> DocumentLoader:
if self.document_loader and self.document_loader.can_handle(source):
return self.document_loader
else:
for loader in self.document_loaders_by_file_type.values():
if loader.can_handle(source):
return loader
raise ValueError("No suitable document loader found for the input.")

def get_document_loader_for_file(self, file: str) -> DocumentLoader:
_, ext = os.path.splitext(file)
Expand Down Expand Up @@ -229,23 +228,12 @@ def classify(self, input: Union[str, IO], classifications: List[Classification],
if image:
return self.classify_from_image(input, classifications)

if isinstance(input, str):
# Check if the input is a valid file path
if os.path.isfile(input):
file_type = get_file_extension(input)
if file_type == 'pdf':
return self.classify_from_path(input, classifications)
elif file_type in SUPPORTED_EXCEL_FORMATS:
return self.classify_from_excel(input, classifications)
else:
raise ValueError(f"Unsupported file type: {input}")
else:
raise ValueError(f"No such file: {input}")
elif hasattr(input, 'read'):
# Check if the input is a stream (like a file object)
return self.classify_from_stream(input, classifications)
else:
raise ValueError("Input must be a file path or a stream.")
document_loader = self.get_document_loader_for_file(input)
if document_loader is None:
raise ValueError("No suitable document loader found for the input.")

content = document_loader.load(input)
return self._classify(content, classifications)

async def classify_async(self, input: Union[str, IO], classifications: List[Classification]):
return await asyncio.to_thread(self.classify, input, classifications)
Expand All @@ -256,7 +244,7 @@ def _extract(self,
response_model,
vision=False,
is_stream=False
):
):
# call all the llm interceptors before calling the llm
for interceptor in self.llm_interceptors:
interceptor.intercept(self.llm)
Expand Down
22 changes: 16 additions & 6 deletions extract_thinker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,19 @@ def is_pdf_stream(stream: Union[BytesIO, str]) -> bool:
# logger.error(f"Error checking if stream is PDF: {e}")
return False

def get_image_type(image_path):
def get_image_type(source):
try:
img = Image.open(image_path)
if isinstance(source, str):
img = Image.open(source)
elif isinstance(source, BytesIO):
source.seek(0)
img = Image.open(source)
source.seek(0)
else:
return None
return img.format.lower()
except IOError as e:
return f"An error occurred: {str(e)}"
return None

def verify_json(json_content: str):
try:
Expand Down Expand Up @@ -134,9 +141,12 @@ def extract_json(text):


def get_file_extension(file_path):
_, ext = os.path.splitext(file_path)
ext = ext[1:] # remove the dot
return ext
if isinstance(file_path, str):
_, ext = os.path.splitext(file_path)
ext = ext[1:] # remove the dot
return ext.lower()
else:
return None


def json_to_formatted_string(data):
Expand Down
2 changes: 1 addition & 1 deletion medium_posts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from utils import remove_json_format

# local path to tesseract
pytesseract.pytesseract.tesseract_cmd = 'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
pytesseract.pytesseract.tesseract_cmd = os.getenv("TESSERACT_PATH")
# docker path to tesseract
#os.environ.get('TESSERACT_PATH', 'tesseract')

Expand Down
77 changes: 39 additions & 38 deletions tests/test_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,44 @@ def arrange_process_with_extractors():
return process


def setup_process_with_textract_extractor():
"""Sets up and returns a process configured with only the Textract extractor."""
# Initialize the Textract document loader
document_loader = DocumentLoaderAWSTextract()

# Initialize the Textract extractor
textract_extractor = Extractor(document_loader)
textract_extractor.load_llm("gpt-4o")

# Create the process with only the Textract extractor
process = Process()
process.add_classify_extractor([[textract_extractor]])

return process


def setup_process_with_gpt4_extractor():
"""Sets up and returns a process configured with only the GPT-4 extractor."""
tesseract_path = os.getenv("TESSERACT_PATH")
if not tesseract_path:
raise ValueError("TESSERACT_PATH environment variable is not set")
print(f"Tesseract path: {tesseract_path}")
document_loader = DocumentLoaderTesseract(tesseract_path)

# Initialize the GPT-4 extractor
gpt_4_extractor = Extractor(document_loader)
gpt_4_extractor.load_llm("gpt-4o")

# Create the process with only the GPT-4 extractor
process = Process()
process.add_classify_extractor([[gpt_4_extractor]])

return process


def test_classify_feature():
"""Test classification using a single feature."""
extractor = setup_extractors()[1] # Using the second configured extractor
extractor = setup_extractors()[1]
result = extractor.classify(INVOICE_FILE_PATH, COMMON_CLASSIFICATIONS)

assert result is not None
Expand Down Expand Up @@ -100,7 +135,7 @@ def test_classify_higher_order():
def test_classify_both():
"""Test classification using both consensus and higher order strategies with a threshold."""
process = arrange_process_with_extractors()
result = process.classify(INVOICE_FILE_PATH, COMMON_CLASSIFICATIONS, strategy=ClassificationStrategy.BOTH, threshold=9)
result = process.classify(INVOICE_FILE_PATH, COMMON_CLASSIFICATIONS, strategy=ClassificationStrategy.CONSENSUS_WITH_THRESHOLD, threshold=9)

assert result is not None
assert isinstance(result, ClassificationResponse)
Expand All @@ -121,37 +156,6 @@ def test_with_contract():
assert result.name == "Invoice"


def setup_process_with_textract_extractor():
"""Sets up and returns a process configured with only the Textract extractor."""
# Initialize the Textract document loader
document_loader = DocumentLoaderAWSTextract()

# Initialize the Textract extractor
textract_extractor = Extractor(document_loader)
textract_extractor.load_llm("gpt-4o")

# Create the process with only the Textract extractor
process = Process()
process.add_classify_extractor([[textract_extractor]])

return process

def setup_process_with_gpt4_extractor():
"""Sets up and returns a process configured with only the GPT-4 extractor."""
tesseract_path = os.getenv("TESSERACT_PATH")
document_loader = DocumentLoaderTesseract(tesseract_path)

# Initialize the GPT-4 extractor
gpt_4_extractor = Extractor(document_loader)
gpt_4_extractor.load_llm("gpt-4o")

# Create the process with only the GPT-4 extractor
process = Process()
process.add_classify_extractor([[gpt_4_extractor]])

return process


def test_with_image():
"""Test classification using both consensus and higher order strategies with a threshold."""
process = setup_process_with_gpt4_extractor()
Expand All @@ -168,6 +172,7 @@ def test_with_image():
assert isinstance(result, ClassificationResponse)
assert result.name == "Invoice"


def test_with_tree():
"""Test classification using the tree strategy"""
process = setup_process_with_gpt4_extractor()
Expand Down Expand Up @@ -228,8 +233,4 @@ def test_with_tree():
result = process.classify(pdf_path, classification_tree, threshold=0.8)

assert result is not None
assert result.name == "Invoice"


if __name__ == "__main__":
test_classify_feature()
assert result.name == "Invoice"

0 comments on commit 8332ec5

Please sign in to comment.