
Commit

enoch3712 committed Jun 13, 2024
2 parents 026baaa + fdfda7f commit d09a324
Showing 6 changed files with 221 additions and 54 deletions.
84 changes: 66 additions & 18 deletions extract_thinker/document_loader/document_loader_tesseract.py
@@ -1,6 +1,7 @@
 from io import BytesIO
 from operator import attrgetter
 import os
+import threading
 from typing import Any, List, Union
 from PIL import Image
 import pytesseract
@@ -10,17 +11,16 @@

 from cachetools import cachedmethod
 from cachetools.keys import hashkey
-import concurrent.futures
+from queue import Queue

-SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "bmp", "tiff"]
+SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "bmp", "tiff", "pdf"]


 class DocumentLoaderTesseract(CachedDocumentLoader):
     def __init__(self, tesseract_cmd, isContainer=False, content=None, cache_ttl=300):
         super().__init__(content, cache_ttl)
         self.tesseract_cmd = tesseract_cmd
         if isContainer:
             # docker path to tesseract
             self.tesseract_cmd = os.environ.get("TESSERACT_PATH", "tesseract")
         pytesseract.pytesseract.tesseract_cmd = self.tesseract_cmd
         if not os.path.isfile(self.tesseract_cmd):
@@ -54,35 +54,83 @@ def load_content_from_stream(self, stream: Union[BytesIO, str]) -> Union[str, ob
         except Exception as e:
             raise Exception(f"Error processing stream: {e}") from e

-    def process_image(self, image):
+    def process_image(self, image: BytesIO) -> str:
         for attempt in range(3):
-            raw_text = str(pytesseract.image_to_string(Image.open(BytesIO(image))))
-            if raw_text:
-                return raw_text
-        raise Exception("Failed to process image after 3 attempts")
+            try:
+                raw_text = str(pytesseract.image_to_string(Image.open(image)))
+                if raw_text:
+                    return raw_text
+            except Exception as e:
+                if attempt == 2:
+                    raise Exception(f"Failed to process image after 3 attempts: {e}")
+        return ""
+
+    def worker(self, input_queue: Queue, output_queue: Queue):
+        while True:
+            image = input_queue.get()
+            if image is None:  # Sentinel to indicate shutdown
+                break
+            try:
+                text = self.process_image(image)
+                output_queue.put((image, text))
+            except Exception as e:
+                output_queue.put((image, str(e)))
+            input_queue.task_done()

     @cachedmethod(cache=attrgetter('cache'), key=lambda self, stream: hashkey(id(stream)))
     def load_content_from_stream_list(self, stream: BytesIO) -> List[Any]:
         images = self.convert_to_images(stream)
+        input_queue = Queue()
+        output_queue = Queue()
+
+        for img in images.values():
+            input_queue.put(BytesIO(img))
+
+        threads = []
+        for _ in range(4):  # Number of worker threads
+            t = threading.Thread(target=self.worker, args=(input_queue, output_queue))
+            t.start()
+            threads.append(t)
+
+        input_queue.join()
+
+        for _ in range(4):
+            input_queue.put(None)

-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = {i: executor.submit(self.process_image, image[i]) for i, image in enumerate(images.values())}
+        for t in threads:
+            t.join()

         contents = []
-        for i, future in futures.items():
-            contents.append({"image": images[i], "content": future.result()})
+        while not output_queue.empty():
+            image, content = output_queue.get()
+            contents.append({"image": image, "content": content})

         return contents

     @cachedmethod(cache=attrgetter('cache'), key=lambda self, input_list: hashkey(id(input_list)))
     def load_content_from_file_list(self, input: List[Union[str, BytesIO]]) -> List[Any]:
         images = self.convert_to_images(input)
+        input_queue = Queue()
+        output_queue = Queue()
+
+        for img in images.values():
+            input_queue.put(BytesIO(img))
+
+        threads = []
+        for _ in range(4):  # Number of worker threads
+            t = threading.Thread(target=self.worker, args=(input_queue, output_queue))
+            t.start()
+            threads.append(t)
+
+        input_queue.join()
+
+        for _ in range(4):
+            input_queue.put(None)

-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            futures = {i: executor.submit(self.process_image, image[i]) for i, image in enumerate(images.values())}
+        for t in threads:
+            t.join()

         contents = []
-        for i, future in futures.items():
-            contents.append({"image": Image.open(BytesIO(images[i][i])), "content": future.result()})
+        while not output_queue.empty():
+            image, content = output_queue.get()
+            contents.append({"image": Image.open(image), "content": content})

         return contents
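
A note on the pattern this diff introduces: OCR work now fans out to a fixed set of worker threads that pull page images from an input Queue and push (image, text) pairs onto an output Queue, with None sentinels shutting the workers down once input_queue.join() returns. Below is a minimal, self-contained sketch of that pattern, using a stand-in ocr() function in place of pytesseract; the function and variable names are illustrative, not part of the library.

import threading
from queue import Queue

def ocr(page: bytes) -> str:
    # Stand-in for pytesseract.image_to_string; returns fake text.
    return f"text for {len(page)} bytes"

def worker(input_queue: Queue, output_queue: Queue) -> None:
    while True:
        page = input_queue.get()
        if page is None:          # sentinel: stop this worker
            break
        try:
            output_queue.put((page, ocr(page)))
        except Exception as e:    # report the failure instead of killing the thread
            output_queue.put((page, str(e)))
        input_queue.task_done()

def run(pages, num_workers: int = 4):
    input_queue, output_queue = Queue(), Queue()
    for page in pages:
        input_queue.put(page)

    threads = [threading.Thread(target=worker, args=(input_queue, output_queue))
               for _ in range(num_workers)]
    for t in threads:
        t.start()

    input_queue.join()            # wait until every page has been task_done()
    for _ in range(num_workers):  # one sentinel per worker
        input_queue.put(None)
    for t in threads:
        t.join()

    results = []
    while not output_queue.empty():
        results.append(output_queue.get())
    return results

if __name__ == "__main__":
    print(run([b"page-1", b"page-2", b"page-3"]))

Compared with the previous concurrent.futures.ThreadPoolExecutor version, a per-page failure here lands on the output queue as a string instead of raising out of future.result() and aborting the whole batch.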
10 changes: 8 additions & 2 deletions extract_thinker/llm.py
@@ -6,10 +6,13 @@


 class LLM:
-    def __init__(self, model: str):
+    def __init__(self, model: str, api_base: str = None, api_key: str = None, api_version: str = None):
         self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.MD_JSON)
         self.model = model
         self.router = None
+        self.api_base = api_base
+        self.api_key = api_key
+        self.api_version = api_version

     def load_router(self, router: Router) -> None:
         self.router = router
@@ -31,7 +34,10 @@ def request(self, messages: List[Dict[str, str]], response_model: str) -> Any:
             model=self.model,
             max_tokens=max_tokens,
             messages=messages,
-            response_model=response_model
+            response_model=response_model,
+            api_base=self.api_base,
+            api_key=self.api_key,
+            api_version=self.api_version
         )

         return response
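
The new api_base, api_key and api_version parameters are forwarded straight to the LiteLLM completion call, so a single LLM instance can be pointed at a specific endpoint such as an Azure OpenAI deployment or a self-hosted server. A hypothetical usage sketch follows; the deployment name, endpoint, key, API version and the pydantic response model are invented for illustration, and it assumes instructor accepts a model class as response_model.

from pydantic import BaseModel

from extract_thinker.llm import LLM


class Invoice(BaseModel):
    number: str
    total: float


# Hypothetical values: swap in your own deployment name, endpoint and key.
llm = LLM(
    "azure/invoice-extractor",
    api_base="https://my-resource.openai.azure.com",
    api_key="<AZURE_API_KEY>",
    api_version="2024-02-15-preview",
)

messages = [{"role": "user", "content": "Invoice INV-001, total 99.50 EUR"}]
invoice = llm.request(messages, Invoice)
print(invoice)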
