Skip to content

Commit

Permalink
Merge pull request #239 from enoch3712/230-document-loader---image-generation-for-url
Browse files Browse the repository at this point in the history

Image generation for URL in DocumentLoader
  • Loading branch information
enoch3712 authored Feb 4, 2025
2 parents 961f8d3 + 060b547 commit a3e052d
Show file tree
Hide file tree
Showing 12 changed files with 401 additions and 82 deletions.
144 changes: 136 additions & 8 deletions extract_thinker/document_loader/document_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,36 @@
from io import BytesIO
from PIL import Image
import pypdfium2 as pdfium
from typing import Any, Dict, Union
from typing import Any, Dict, Union, List
from cachetools import TTLCache
import os
import magic
from extract_thinker.utils import get_file_extension, check_mime_type
from playwright.sync_api import sync_playwright
from urllib.parse import urlparse
import base64
import math

class DocumentLoader(ABC):
def __init__(self, content: Any = None, cache_ttl: int = 300):
# SUPPORTED_FORMATS = [
# "pdf", "jpg", "jpeg", "png", "tiff", "bmp"
# ]

def __init__(self, content: Any = None, cache_ttl: int = 300, screenshot_timeout: int = 1000):
    """Initialize the loader with optional content and caching/screenshot settings.

    Args:
        content: Initial content to associate with this loader.
        cache_ttl: Time-to-live for cached entries, in seconds.
        screenshot_timeout: Milliseconds to wait for page content to load
            when capturing a screenshot of a URL.
    """
    # Source state
    self.content = content
    self.file_path = None
    self.is_url = False  # Set to True once a URL source is detected

    # Processing options
    self.vision_mode = False
    self.max_image_size = None  # No resize limit unless explicitly set
    self.screenshot_timeout = screenshot_timeout

    # Bounded TTL cache for expensive load results
    self.cache = TTLCache(maxsize=100, ttl=cache_ttl)

def set_max_image_size(self, size: int) -> None:
"""Set the maximum image size."""
Expand All @@ -31,6 +42,10 @@ def set_vision_mode(self, enabled: bool = True) -> None:
"""Enable or disable vision mode processing."""
self.vision_mode = enabled

def set_screenshot_timeout(self, timeout: int) -> None:
"""Set the screenshot timeout in milliseconds for capturing a screenshot from a URL."""
self.screenshot_timeout = timeout

def can_handle(self, source: Union[str, BytesIO]) -> bool:
"""
Checks if the loader can handle the given source.
Expand Down Expand Up @@ -60,7 +75,6 @@ def _can_handle_file_path(self, file_path: str) -> bool:
def _can_handle_stream(self, stream: BytesIO) -> bool:
"""Checks if the loader can handle the given BytesIO stream."""
try:
# Read the first few bytes to determine file type
mime = magic.from_buffer(stream.getvalue(), mime=True)
stream.seek(0) # Reset stream position
return check_mime_type(mime, self.SUPPORTED_FORMATS)
Expand All @@ -85,19 +99,36 @@ def convert_to_images(self, file: Union[str, io.BytesIO, io.BufferedReader], sca
raise TypeError("file must be a file path (str) or a file-like stream")

def _convert_file_to_images(self, file_path: str, scale: float) -> Dict[int, bytes]:
    """Convert a source to images, handling both URLs and local files.

    Args:
        file_path: Path to a local file, or a URL.
        scale: Rendering scale passed to the PDF rasterizer.

    Returns:
        Dict mapping page index to image bytes. For a URL source, key 0
        maps to the list of PNG-encoded vertical chunks of the page
        screenshot; for a local image, key 0 maps to the raw file bytes.

    Raises:
        ValueError: If a screenshot cannot be captured from a URL.
    """
    if self._is_url(file_path):
        self.is_url = True  # Remember that this source came from a URL
        try:
            screenshot = self._capture_screenshot_from_url(file_path)
            # Load the screenshot so it can be resized and chunked
            img = Image.open(BytesIO(screenshot))
            img = self._resize_if_needed(img)

            # Split the (potentially very tall) page into vertical chunks
            chunks = self._split_image_vertically(img)

            # All chunks belong to the single logical "page 0"
            return {0: chunks}

        except Exception as e:
            # Chain the original exception so the root cause is preserved
            raise ValueError(f"Failed to capture screenshot from URL: {str(e)}") from e

    # Probe whether the local file is already an image. Use a context
    # manager: a bare Image.open() is lazy and would leak the file handle.
    try:
        with Image.open(file_path):
            is_image = True
    except IOError:
        is_image = False

    if is_image:
        # Already an image: return the raw bytes as page 0
        with open(file_path, "rb") as f:
            return {0: f.read()}

    # Not an image: rasterize it as a PDF
    return self._convert_pdf_to_images(pdfium.PdfDocument(file_path), scale)

def _convert_stream_to_images(self, file_stream: io.BytesIO, scale: float) -> Dict[int, bytes]:
Expand Down Expand Up @@ -163,13 +194,15 @@ def can_handle_vision(self, source: Union[str, BytesIO]) -> bool:
Checks if the loader can handle the source in vision mode.
Args:
source: Either a file path (str) or a BytesIO stream
source: Either a file path (str), URL, or a BytesIO stream
Returns:
bool: True if the loader can handle the source in vision mode
"""
try:
if isinstance(source, str):
if self._is_url(source):
return True # URLs are always supported in vision mode
ext = get_file_extension(source).lower()
return ext in ['pdf', 'jpg', 'jpeg', 'png', 'tiff', 'bmp']
elif isinstance(source, BytesIO):
Expand Down Expand Up @@ -210,4 +243,99 @@ def can_handle_paginate(self, source: Union[str, BytesIO]) -> bool:
# List of extensions that support pagination
return ext in ['pdf']
except Exception:
return False
return False

@staticmethod
def _check_playwright_dependencies():
"""
Check if the playwright dependency is installed.
Raises:
ImportError: If playwright is not installed.
"""
try:
from playwright.sync_api import sync_playwright
except ImportError:
raise ImportError(
"You are using vision with url. You need to install playwright."
"`pip install playwright` and run `playwright install`."
)

def _capture_screenshot_from_url(self, url: str) -> bytes:
    """
    Capture a full-page screenshot of a URL using Playwright.

    Args:
        url: The URL to capture.

    Returns:
        bytes: The screenshot image data.
    """
    # Fail fast with install instructions if playwright is missing.
    self._check_playwright_dependencies()

    from playwright.sync_api import sync_playwright  # Import after the dependency check

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        try:
            # new_page() is inside the try so the browser is closed even
            # if page creation fails (the original leaked it here).
            page = browser.new_page()

            # Navigate and wait for network activity to quiet down
            page.goto(url, wait_until='networkidle')

            # Best-effort dismissal of cookie-consent banners
            # (customize the selector as needed)
            try:
                page.click('button:has-text("Accept")', timeout=10000)
            except Exception:
                pass  # Ignore if no cookie banner is found

            # Give dynamic content time to render (configurable timeout)
            page.wait_for_timeout(self.screenshot_timeout)

            # Capture the full scrollable page, not just the viewport
            return page.screenshot(full_page=True)

        finally:
            browser.close()

def _split_image_vertically(self, img: Image.Image, chunk_height: int = 1000) -> List[bytes]:
    """
    Split a tall PIL image into vertical chunks of at most `chunk_height` pixels.

    Args:
        img: PIL Image to split.
        chunk_height: Maximum height of each chunk in pixels.

    Returns:
        List of PNG-encoded bytes, one entry per chunk, ordered top to bottom.
    """
    width, height = img.size
    total_chunks = math.ceil(height / chunk_height)

    pieces: List[bytes] = []
    for index in range(total_chunks):
        upper = index * chunk_height
        # The final chunk may be shorter than chunk_height
        lower = min(upper + chunk_height, height)

        segment = img.crop((0, upper, width, lower))

        # Encode the chunk as PNG in memory
        buffer = io.BytesIO()
        segment.save(buffer, format="PNG", optimize=True)
        pieces.append(buffer.getvalue())

    return pieces

def _is_url(self, source: str) -> bool:
"""Check if the source string is a URL."""
try:
result = urlparse(source)
return bool(result.scheme and result.netloc)
except:
return False
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __post_init__(self):
class DocumentLoaderBeautifulSoup(CachedDocumentLoader):
"""Loader that uses BeautifulSoup4 to load HTML content."""

SUPPORTED_FORMATS = ['html', 'htm']
SUPPORTED_FORMATS = [
'html', 'htm', 'url' # Add URL support
]

def __init__(
self,
Expand Down Expand Up @@ -257,9 +259,7 @@ def load(self, source: Union[str, BytesIO]) -> List[Dict[str, Any]]:
raise ValueError(f"Error loading HTML content: {str(e)}")

def can_handle(self, source: Union[str, BytesIO]) -> bool:
"""Check if the loader can handle this source."""
if isinstance(source, BytesIO):
"""Override to add URL support."""
if isinstance(source, str) and self._is_url(source):
return True
if self._is_url(source):
return True
return get_file_extension(source) in self.SUPPORTED_FORMATS
return super().can_handle(source)
27 changes: 18 additions & 9 deletions extract_thinker/document_loader/document_loader_docling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from io import BytesIO
from typing import Any, Dict, List, Union, Optional
from dataclasses import dataclass, field
from urllib.parse import urlparse

from cachetools import cachedmethod
from cachetools.keys import hashkey
Expand Down Expand Up @@ -120,7 +121,9 @@ class DocumentLoaderDocling(CachedDocumentLoader):
# XML (including PubMed .nxml)
"xml", "nxml",
# Plain text
"txt"
"txt",
# URL support
"url"
]

def __init__(
Expand Down Expand Up @@ -212,37 +215,43 @@ def can_handle(self, source: Union[str, BytesIO]) -> bool:
self.vision_mode
))
def load(self, source: Union[str, BytesIO]) -> List[Dict[str, Any]]:
from docling.document_converter import ConversionResult
"""
Load and parse the document using Docling.
Returns:
A list of dictionaries, each representing a "page" with:
- "content": text from that page
- "image": optional image bytes if vision_mode is True
- "markdown": Markdown string of that page
"""
if not self.can_handle(source):
raise ValueError(f"Cannot handle source: {source}")

# Convert the source to a docling "ConversionResult"
conv_result = self._docling_convert(source)

test = conv_result.document.export_to_markdown()
print(test)
conv_result: ConversionResult = self._docling_convert(source)

# Build the output list of page data
# If the source is a URL, return a single page with all the content.
if isinstance(source, str) and self._is_url(source):
content = conv_result.document.export_to_markdown()
print(content) # Log the exported markdown, if needed
page_output = {"content": content, "image": None}
# Handle image extraction if vision_mode is enabled
if self.vision_mode:
images_dict = self.convert_to_images(source)
page_output["images"] = images_dict.get(0)
return [page_output]

# Build the output list of page data for non-URL sources
pages_output = []
for p in conv_result.pages:
page_dict = {
"content": conv_result.document.export_to_markdown(page_no=p.page_no+1),
"image": None
}

# Handle image extraction if vision_mode is enabled
if self.vision_mode:
images_dict = self.convert_to_images(source)
page_dict["image"] = images_dict.get(p.page_no)

pages_output.append(page_dict)

# Fallback for documents without explicit pages
Expand Down
Loading

0 comments on commit a3e052d

Please sign in to comment.