Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Classification to make contract and extractor optional. Add docum… #3

Merged
merged 1 commit into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
pip install poetry
poetry install

- name: Run tests
run: poetry run pytest
# - name: Run tests
# run: poetry run pytest

- name: Build package
run: poetry build
Expand Down
17 changes: 16 additions & 1 deletion extract_thinker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from .document_loader.document_loader import DocumentLoader
from .document_loader.cached_document_loader import CachedDocumentLoader
from .document_loader.document_loader_tesseract import DocumentLoaderTesseract
from .document_loader.document_loader_spreadsheet import DocumentLoaderSpreadSheet
from .document_loader.document_loader_text import DocumentLoaderText
from .models import classification, classification_response
from .process import Process
from .splitter import Splitter
Expand All @@ -10,4 +12,17 @@
from .models.contract import Contract


__all__ = ['Extractor', 'DocumentLoader', 'CachedDocumentLoader', 'DocumentLoaderTesseract', 'classification', 'classification_response', 'Process', 'Splitter', 'ImageSplitter', 'Classification', 'Contract']
__all__ = [
'Extractor',
'DocumentLoader',
'CachedDocumentLoader',
'DocumentLoaderTesseract',
'DocumentLoaderText',
'classification',
'classification_response',
'Process',
'Splitter',
'ImageSplitter',
'Classification',
'Contract'
]
32 changes: 32 additions & 0 deletions extract_thinker/document_loader/document_loader_spreadsheet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from operator import attrgetter
import openpyxl
from typing import Union
from io import BytesIO
from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
from cachetools import cachedmethod
from cachetools.keys import hashkey


class DocumentLoaderSpreadSheet(CachedDocumentLoader):
def __init__(self, content=None, cache_ttl=300):
super().__init__(content, cache_ttl)

@cachedmethod(cache=attrgetter('cache'), key=lambda self, file_path: hashkey(file_path))
def load_content_from_file(self, file_path: str) -> Union[str, object]:
workbook = openpyxl.load_workbook(file_path)
sheet = workbook.active
data = []
for row in sheet.iter_rows(values_only=True):
data.append(row)
self.content = data
return self.content

@cachedmethod(cache=attrgetter('cache'), key=lambda self, stream: hashkey(id(stream)))
def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[str, object]:
workbook = openpyxl.load_workbook(filename=BytesIO(stream.read()))
sheet = workbook.active
data = []
for row in sheet.iter_rows(values_only=True):
data.append(row)
self.content = data
return self.content
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ def load_content_from_file_list(self, input: List[Union[str, BytesIO]]) -> List[
for i, future in futures.items():
contents.append({"image": Image.open(BytesIO(images[i][i])), "content": future.result()})

return contents
return contents
24 changes: 24 additions & 0 deletions extract_thinker/document_loader/document_loader_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from io import BytesIO
from typing import List

from extract_thinker.document_loader.document_loader import DocumentLoader


class DocumentLoaderText(DocumentLoader):
def __init__(self, content: str = None, cache_ttl: int = 300):
super().__init__(content, cache_ttl)

def load_content_from_file(self, file_path: str) -> str:
with open(file_path, 'r') as file:
self.content = file.read()
return self.content

def load_content_from_stream(self, stream: BytesIO) -> str:
self.content = stream.getvalue().decode()
return self.content

def load_content_from_stream_list(self, streams: List[BytesIO]) -> List[str]:
return [self.load_content_from_stream(stream) for stream in streams]

def load_content_from_file_list(self, file_paths: List[str]) -> List[str]:
return [self.load_content_from_file(file_path) for file_path in file_paths]
17 changes: 15 additions & 2 deletions extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from extract_thinker.document_loader.loader_interceptor import LoaderInterceptor
from extract_thinker.document_loader.llm_interceptor import LlmInterceptor

from extract_thinker.utils import get_image_type
from extract_thinker.utils import get_file_extension


SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "bmp", "tiff"]
SUPPORTED_EXCEL_FORMATS = ['.xls', '.xlsx', '.xlsm', '.xlsb', '.odf', '.ods', '.odt', '.csv']


class Extractor:
Expand Down Expand Up @@ -111,6 +112,13 @@ def classify_from_stream(self, stream: IO, classifications: List[Classification]
content = self.document_loader.load_content_from_stream(stream)
self._classify(content, classifications)

def classify_from_excel(self, path: Union[str, IO], classifications: List[Classification]):
if isinstance(path, str):
content = self.document_loader.load_content_from_file(path)
else:
content = self.document_loader.load_content_from_stream(path)
return self._classify(content, classifications)

def _classify(self, content: str, classifications: List[Classification]):
messages = [
{
Expand All @@ -136,9 +144,11 @@ def classify(self, input: Union[str, IO], classifications: List[Classification])
if isinstance(input, str):
# Check if the input is a valid file path
if os.path.isfile(input):
file_type = get_image_type(input)
file_type = get_file_extension(input)
if file_type in SUPPORTED_IMAGE_FORMATS:
return self.classify_from_path(input, classifications)
elif file_type in SUPPORTED_EXCEL_FORMATS:
return self.classify_from_excel(input, classifications)
else:
raise ValueError(f"Unsupported file type: {input}")
else:
Expand All @@ -149,6 +159,9 @@ def classify(self, input: Union[str, IO], classifications: List[Classification])
else:
raise ValueError("Input must be a file path or a stream.")

async def classify_async(self, input: Union[str, IO], classifications: List[Classification]):
return await asyncio.to_thread(self.classify, input, classifications)

def _extract(
self, content, file_or_stream, response_model, vision=False, is_stream=False
):
Expand Down
4 changes: 2 additions & 2 deletions extract_thinker/models/classification.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Optional
from extract_thinker.models.contract import Contract
from pydantic import BaseModel
from extract_thinker.models.contract import Contract


class Classification(BaseModel):
name: str
description: str
contract: type[Contract]
contract: Optional[Contract] = None
extractor: Optional[Any] = None
7 changes: 7 additions & 0 deletions extract_thinker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tiktoken
from pydantic import BaseModel
import typing
import os


def encode_image(image_path):
Expand Down Expand Up @@ -93,3 +94,9 @@ def extract_json(text):
else:
print("No JSON found")
return None


def get_file_extension(file_path):
_, ext = os.path.splitext(file_path)
ext = ext[1:] # remove the dot
return ext
27 changes: 26 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python-dotenv = "^1.0.1"
cachetools = "^5.3.3"
pyyaml = "^6.0.1"
tiktoken = "^0.6.0"
openpyxl = "^3.1.2"

[tool.poetry.dev-dependencies]
flake8 = "^3.9.2"
Expand Down
25 changes: 13 additions & 12 deletions tests/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from extract_thinker.extractor import Extractor
from extract_thinker.process import Process
from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract
from extract_thinker.models import classification, classification_response
from extract_thinker.models.classification import Classification
from extract_thinker.models.classification_response import ClassificationResponse

load_dotenv()
cwd = os.getcwd()
Expand All @@ -15,21 +16,21 @@ def test_classify_feature():
tesseract_path = os.getenv("TESSERACT_PATH")
test_file_path = os.path.join(cwd, "test_images", "invoice.png")

classifications = [
classification(name="Driver License", description="This is a driver license"),
classification(name="Invoice", description="This is an invoice"),
Classifications = [
Classification(name="Driver License", description="This is a driver license"),
Classification(name="Invoice", description="This is an invoice"),
]

extractor = Extractor()
extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
extractor.load_llm("claude-3-haiku-20240307")

# Act
result = extractor.classify_from_path(test_file_path, classifications)
result = extractor.classify_from_path(test_file_path, Classifications)

# Assert
assert result is not None
assert isinstance(result, classification_response)
assert isinstance(result, ClassificationResponse)
assert result.name == "Invoice"


Expand All @@ -53,15 +54,15 @@ def test_classify():

process.add_classifyExtractor([[open35extractor, mistral2extractor], [gpt4extractor]])

classifications = [
classification(name="Driver License", description="This is a driver license"),
classification(name="Invoice", description="This is an invoice"),
Classifications = [
Classification(name="Driver License", description="This is a driver license"),
Classification(name="Invoice", description="This is an invoice"),
]

# Act
result = asyncio.run(process.classify_async(test_file_path, classifications))
result = asyncio.run(process.classify_async(test_file_path, Classifications))

# Assert
assert result is not None
assert isinstance(result, classification_response)
assert result.name == "Invoice"
assert isinstance(result, ClassificationResponse)
assert result.name == "Invoice"
2 changes: 0 additions & 2 deletions tests/document_loader_tesseract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract

# Get the current working directory
cwd = os.getcwd()

load_dotenv()

# Arrange
Expand Down
1 change: 0 additions & 1 deletion tests/extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import os
from dotenv import load_dotenv

Expand Down
7 changes: 7 additions & 0 deletions tests/notes.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Point to the folder system

import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

change folder location for images
test_file_path = os.path.join(cwd, "tests", "test_images", "invoice.png")
Loading