Skip to content

Commit 0395464

Browse files
authored
Merge pull request #9 from Zipstack/fix/test-case-outputs-updated
fix: Updated test case outputs and fuzzy assertion
2 parents 2929529 + e30178e commit 0395464

8 files changed

+230
-236
lines changed

src/unstract/llmwhisperer/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.21.0"
1+
__version__ = "0.22.0"
22

33
from .client import LLMWhispererClient # noqa: F401
44

tests/client_test.py

+21-73
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
import logging
22
import os
3-
import unittest
3+
from difflib import SequenceMatcher, unified_diff
44
from pathlib import Path
55

66
import pytest
7-
import requests
8-
9-
from unstract.llmwhisperer import LLMWhispererClient
107

118
logger = logging.getLogger(__name__)
129

@@ -23,9 +20,7 @@ def test_get_usage_info(client):
2320
"subscription_plan",
2421
"today_page_count",
2522
]
26-
assert set(usage_info.keys()) == set(
27-
expected_keys
28-
), f"usage_info {usage_info} does not contain the expected keys"
23+
assert set(usage_info.keys()) == set(expected_keys), f"usage_info {usage_info} does not contain the expected keys"
2924

3025

3126
@pytest.mark.parametrize(
@@ -41,85 +36,38 @@ def test_get_usage_info(client):
4136
)
4237
def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
4338
file_path = os.path.join(data_dir, input_file)
44-
response = client.whisper(
39+
whisper_result = client.whisper(
4540
processing_mode=processing_mode,
4641
output_mode=output_mode,
4742
file_path=file_path,
4843
timeout=200,
4944
)
50-
logger.debug(response)
45+
logger.debug(whisper_result)
5146

5247
exp_basename = f"{Path(input_file).stem}.{processing_mode}.{output_mode}.txt"
5348
exp_file = os.path.join(data_dir, "expected", exp_basename)
54-
with open(exp_file, encoding="utf-8") as f:
55-
exp = f.read()
5649

57-
assert isinstance(response, dict)
58-
assert response["status_code"] == 200
59-
assert response["extracted_text"] == exp
50+
assert_extracted_text(exp_file, whisper_result, processing_mode, output_mode)
6051

6152

62-
# TODO: Review and port to pytest based tests
63-
class TestLLMWhispererClient(unittest.TestCase):
64-
@unittest.skip("Skipping test_whisper")
65-
def test_whisper(self):
66-
client = LLMWhispererClient()
67-
# response = client.whisper(
68-
# url="https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
69-
# )
70-
response = client.whisper(
71-
file_path="test_data/restaurant_invoice_photo.pdf",
72-
timeout=200,
73-
store_metadata_for_highlighting=True,
74-
)
75-
print(response)
76-
# self.assertIsInstance(response, dict)
53+
def assert_extracted_text(file_path, whisper_result, mode, output_mode):
54+
with open(file_path, encoding="utf-8") as f:
55+
exp = f.read()
7756

78-
# @unittest.skip("Skipping test_whisper")
79-
def test_whisper_stream(self):
80-
client = LLMWhispererClient()
81-
download_url = (
82-
"https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
83-
)
84-
# Create a stream of download_url and pass it to whisper
85-
response_download = requests.get(download_url, stream=True)
86-
response_download.raise_for_status()
87-
response = client.whisper(
88-
stream=response_download.iter_content(chunk_size=1024),
89-
timeout=200,
90-
store_metadata_for_highlighting=True,
91-
)
92-
print(response)
93-
# self.assertIsInstance(response, dict)
57+
assert isinstance(whisper_result, dict)
58+
assert whisper_result["status_code"] == 200
9459

95-
@unittest.skip("Skipping test_whisper_status")
96-
def test_whisper_status(self):
97-
client = LLMWhispererClient()
98-
response = client.whisper_status(
99-
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
100-
)
101-
logger.info(response)
102-
self.assertIsInstance(response, dict)
60+
# For OCR based processing
61+
threshold = 0.97
10362

104-
@unittest.skip("Skipping test_whisper_retrieve")
105-
def test_whisper_retrieve(self):
106-
client = LLMWhispererClient()
107-
response = client.whisper_retrieve(
108-
whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
109-
)
110-
logger.info(response)
111-
self.assertIsInstance(response, dict)
63+
# For text based processing
64+
if mode == "native_text" and output_mode == "text":
65+
threshold = 0.99
66+
extracted_text = whisper_result["extracted_text"]
67+
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
11268

113-
@unittest.skip("Skipping test_whisper_highlight_data")
114-
def test_whisper_highlight_data(self):
115-
client = LLMWhispererClient()
116-
response = client.highlight_data(
117-
whisper_hash="9924d865|5f1d285a7cf18d203de7af1a1abb0a3a",
118-
search_text="Indiranagar",
69+
if similarity < threshold:
70+
diff = "\n".join(
71+
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
11972
)
120-
logger.info(response)
121-
self.assertIsInstance(response, dict)
122-
123-
124-
if __name__ == "__main__":
125-
unittest.main()
73+
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")

0 commit comments

Comments
 (0)