Skip to content

Commit 768257c

Browse files
Text mode test case assertion updated to be fuzzy with threshold 0.99
1 parent 2625c1d commit 768257c

File tree

1 file changed

+19
-79
lines changed

1 file changed

+19
-79
lines changed

tests/client_test.py

+19-79
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import logging
22
import os
3-
import unittest
43
from difflib import SequenceMatcher, unified_diff
54
from pathlib import Path
65

76
import pytest
8-
import requests
9-
10-
from unstract.llmwhisperer import LLMWhispererClient
117

128
logger = logging.getLogger(__name__)
139

@@ -40,93 +36,37 @@ def test_get_usage_info(client):
4036
)
4137
def test_whisper(client, data_dir, processing_mode, output_mode, input_file):
4238
file_path = os.path.join(data_dir, input_file)
43-
response = client.whisper(
39+
whisper_result = client.whisper(
4440
processing_mode=processing_mode,
4541
output_mode=output_mode,
4642
file_path=file_path,
4743
timeout=200,
4844
)
49-
logger.debug(response)
45+
logger.debug(whisper_result)
5046

5147
exp_basename = f"{Path(input_file).stem}.{processing_mode}.{output_mode}.txt"
5248
exp_file = os.path.join(data_dir, "expected", exp_basename)
53-
with open(exp_file, encoding="utf-8") as f:
54-
exp = f.read()
55-
56-
assert isinstance(response, dict)
57-
assert response["status_code"] == 200
58-
59-
# For text based processing, perform a strict match
60-
if processing_mode == "text" and output_mode == "text":
61-
assert response["extracted_text"] == exp
62-
# For OCR based processing, perform a fuzzy match
63-
else:
64-
extracted_text = response["extracted_text"]
65-
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
66-
threshold = 0.97
67-
68-
if similarity < threshold:
69-
diff = "\n".join(
70-
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
71-
)
72-
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
49+
assert_extracted_text(exp_file, whisper_result, processing_mode, output_mode)
7350

7451

75-
# TODO: Review and port to pytest based tests
76-
class TestLLMWhispererClient(unittest.TestCase):
77-
@unittest.skip("Skipping test_whisper")
78-
def test_whisper(self):
79-
client = LLMWhispererClient()
80-
# response = client.whisper(
81-
# url="https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
82-
# )
83-
response = client.whisper(
84-
file_path="test_data/restaurant_invoice_photo.pdf",
85-
timeout=200,
86-
store_metadata_for_highlighting=True,
87-
)
88-
print(response)
89-
# self.assertIsInstance(response, dict)
52+
def assert_extracted_text(file_path, whisper_result, mode, output_mode):
53+
with open(file_path, encoding="utf-8") as f:
54+
exp = f.read()
9055

91-
# @unittest.skip("Skipping test_whisper")
92-
def test_whisper_stream(self):
93-
client = LLMWhispererClient()
94-
download_url = "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
95-
# Create a stream of download_url and pass it to whisper
96-
response_download = requests.get(download_url, stream=True)
97-
response_download.raise_for_status()
98-
response = client.whisper(
99-
stream=response_download.iter_content(chunk_size=1024),
100-
timeout=200,
101-
store_metadata_for_highlighting=True,
102-
)
103-
print(response)
104-
# self.assertIsInstance(response, dict)
56+
assert isinstance(whisper_result, dict)
57+
assert whisper_result["status_code"] == 200
10558

106-
@unittest.skip("Skipping test_whisper_status")
107-
def test_whisper_status(self):
108-
client = LLMWhispererClient()
109-
response = client.whisper_status(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
110-
logger.info(response)
111-
self.assertIsInstance(response, dict)
59+
# For OCR based processing
60+
threshold = 0.97
11261

113-
@unittest.skip("Skipping test_whisper_retrieve")
114-
def test_whisper_retrieve(self):
115-
client = LLMWhispererClient()
116-
response = client.whisper_retrieve(whisper_hash="7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a")
117-
logger.info(response)
118-
self.assertIsInstance(response, dict)
62+
# For text based processing
63+
if mode == "native_text" and output_mode == "text":
64+
threshold = 0.99
65+
extracted_text = whisper_result["extracted_text"]
66+
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
11967

120-
@unittest.skip("Skipping test_whisper_highlight_data")
121-
def test_whisper_highlight_data(self):
122-
client = LLMWhispererClient()
123-
response = client.highlight_data(
124-
whisper_hash="9924d865|5f1d285a7cf18d203de7af1a1abb0a3a",
125-
search_text="Indiranagar",
68+
if similarity < threshold:
69+
diff = "\n".join(
70+
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
12671
)
127-
logger.info(response)
128-
self.assertIsInstance(response, dict)
129-
130-
131-
if __name__ == "__main__":
132-
unittest.main()
72+
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")

0 commit comments

Comments
 (0)