Skip to content

Commit c497fd3

Browse files
authored
Merge pull request #17 from Zipstack/update-client_for_status_api
Updated the status API check based on the changes at llm whisperer side
2 parents 82baf42 + ba6ef48 commit c497fd3

File tree

4 files changed

+121
-69
lines changed

4 files changed

+121
-69
lines changed

src/unstract/llmwhisperer/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.0.0"
1+
__version__ = "2.1.0"
22

33
from .client import LLMWhispererClient # noqa: F401
44
from .client_v2 import LLMWhispererClientV2 # noqa: F401

src/unstract/llmwhisperer/client_v2.py

+30-38
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ class LLMWhispererClientV2:
6262
client's activities and errors.
6363
"""
6464

65-
formatter = logging.Formatter(
66-
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
67-
)
65+
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
6866
logger = logging.getLogger(__name__)
6967
log_stream_handler = logging.StreamHandler()
7068
log_stream_handler.setFormatter(formatter)
@@ -108,7 +106,6 @@ def __init__(
108106
self.logger.setLevel(logging.ERROR)
109107
self.logger.setLevel(logging_level)
110108
self.logger.debug("logging_level set to %s", logging_level)
111-
112109
if base_url == "":
113110
self.base_url = os.getenv("LLMWHISPERER_BASE_URL_V2", BASE_URL_V2)
114111
else:
@@ -121,6 +118,15 @@ def __init__(
121118
self.api_key = api_key
122119

123120
self.headers = {"unstract-key": self.api_key}
121+
# For test purpose
122+
# self.headers = {
123+
# "Subscription-Id": "python-client",
124+
# "Subscription-Name": "python-client",
125+
# "User-Id": "python-client-user",
126+
# "Product-Id": "python-client-product",
127+
# "Product-Name": "python-client-product",
128+
# "Start-Date": "2024-07-09",
129+
# }
124130

125131
def get_usage_info(self) -> dict:
126132
"""Retrieves the usage information of the LLMWhisperer API.
@@ -283,9 +289,7 @@ def generate():
283289
)
284290
else:
285291
params["url_in_post"] = True
286-
req = requests.Request(
287-
"POST", api_url, params=params, headers=self.headers, data=url
288-
)
292+
req = requests.Request("POST", api_url, params=params, headers=self.headers, data=url)
289293
prepared = req.prepare()
290294
s = requests.Session()
291295
response = s.send(prepared, timeout=wait_timeout, stream=should_stream)
@@ -310,42 +314,30 @@ def generate():
310314
message["message"] = "Whisper client operation failed"
311315
message["extraction"] = {}
312316
return message
317+
if status["status"] == "accepted":
318+
self.logger.debug(f'Whisper-hash:{whisper_hash} | STATUS: {status["status"]}...')
313319
if status["status"] == "processing":
314-
self.logger.debug(
315-
f"Whisper-hash:{whisper_hash} | STATUS: processing..."
316-
)
317-
elif status["status"] == "delivered":
318-
self.logger.debug(
319-
f"Whisper-hash:{whisper_hash} | STATUS: Already delivered!"
320-
)
321-
raise LLMWhispererClientException(
322-
{
323-
"status_code": -1,
324-
"message": "Whisper operation already delivered",
325-
}
326-
)
327-
elif status["status"] == "unknown":
328-
self.logger.debug(
329-
f"Whisper-hash:{whisper_hash} | STATUS: unknown..."
330-
)
331-
raise LLMWhispererClientException(
332-
{
333-
"status_code": -1,
334-
"message": "Whisper operation status unknown",
335-
}
336-
)
337-
elif status["status"] == "failed":
338-
self.logger.debug(
339-
f"Whisper-hash:{whisper_hash} | STATUS: failed..."
340-
)
320+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: processing...")
321+
322+
elif status["status"] == "error":
323+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: failed...")
324+
self.logger.error(f'Whisper-hash:{whisper_hash} | STATUS: failed with {status["message"]}')
325+
message["status_code"] = -1
326+
message["message"] = status["message"]
327+
message["status"] = "error"
328+
message["extraction"] = {}
329+
return message
330+
elif "error" in status["status"]:
331+
# for backward compatabity
332+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: failed...")
333+
self.logger.error(f'Whisper-hash:{whisper_hash} | STATUS: failed with {status["status"]}')
341334
message["status_code"] = -1
342-
message["message"] = "Whisper operation failed"
335+
message["message"] = status["status"]
336+
message["status"] = "error"
343337
message["extraction"] = {}
344338
return message
345339
elif status["status"] == "processed":
346-
self.logger.debug(
347-
f"Whisper-hash:{whisper_hash} | STATUS: processed!"
348-
)
340+
self.logger.debug(f"Whisper-hash:{whisper_hash} | STATUS: processed!")
349341
resultx = self.whisper_retrieve(whisper_hash=whisper_hash)
350342
if resultx["status_code"] == 200:
351343
message["status_code"] = 200

tests/integration/client_v2_test.py

+89-30
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_get_usage_info(client_v2):
1818
"current_page_count_form",
1919
"current_page_count_high_quality",
2020
"current_page_count_native_text",
21+
"current_page_count_excel",
2122
"daily_quota",
2223
"monthly_quota",
2324
"overage_page_count",
@@ -44,7 +45,10 @@ def test_get_usage_info(client_v2):
4445
def test_whisper_v2(client_v2, data_dir, output_mode, mode, input_file):
4546
file_path = os.path.join(data_dir, input_file)
4647
whisper_result = client_v2.whisper(
47-
mode=mode, output_mode=output_mode, file_path=file_path, wait_for_completion=True
48+
mode=mode,
49+
output_mode=output_mode,
50+
file_path=file_path,
51+
wait_for_completion=True,
4852
)
4953
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")
5054

@@ -54,24 +58,62 @@ def test_whisper_v2(client_v2, data_dir, output_mode, mode, input_file):
5458
assert_extracted_text(exp_file, whisper_result, mode, output_mode)
5559

5660

61+
@pytest.mark.parametrize(
62+
"output_mode, mode, input_file",
63+
[
64+
("layout_preserving", "high_quality", "test.json"),
65+
],
66+
)
67+
def test_whisper_v2_error(client_v2, data_dir, output_mode, mode, input_file):
68+
file_path = os.path.join(data_dir, input_file)
69+
70+
whisper_result = client_v2.whisper(
71+
mode=mode,
72+
output_mode=output_mode,
73+
file_path=file_path,
74+
wait_for_completion=True,
75+
)
76+
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")
77+
78+
assert_error_message(whisper_result)
79+
80+
5781
@pytest.mark.parametrize(
5882
"output_mode, mode, url, input_file, page_count",
5983
[
60-
("layout_preserving", "native_text", "https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
61-
"credit_card.pdf", 7),
62-
("layout_preserving", "low_cost", "https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
63-
"credit_card.pdf", 7),
64-
("layout_preserving", "high_quality", "https://unstractpocstorage.blob.core.windows.net/public/scanned_bill.pdf",
65-
"restaurant_invoice_photo.pdf", 1),
66-
("layout_preserving", "form", "https://unstractpocstorage.blob.core.windows.net/public/scanned_form.pdf",
67-
"handwritten-form.pdf", 1),
68-
]
84+
(
85+
"layout_preserving",
86+
"native_text",
87+
"https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
88+
"credit_card.pdf",
89+
7,
90+
),
91+
(
92+
"layout_preserving",
93+
"low_cost",
94+
"https://unstractpocstorage.blob.core.windows.net/public/Amex.pdf",
95+
"credit_card.pdf",
96+
7,
97+
),
98+
(
99+
"layout_preserving",
100+
"high_quality",
101+
"https://unstractpocstorage.blob.core.windows.net/public/scanned_bill.pdf",
102+
"restaurant_invoice_photo.pdf",
103+
1,
104+
),
105+
(
106+
"layout_preserving",
107+
"form",
108+
"https://unstractpocstorage.blob.core.windows.net/public/scanned_form.pdf",
109+
"handwritten-form.pdf",
110+
1,
111+
),
112+
],
69113
)
70114
def test_whisper_v2_url_in_post(client_v2, data_dir, output_mode, mode, url, input_file, page_count):
71115
usage_before = client_v2.get_usage_info()
72-
whisper_result = client_v2.whisper(
73-
mode=mode, output_mode=output_mode, url=url, wait_for_completion=True
74-
)
116+
whisper_result = client_v2.whisper(mode=mode, output_mode=output_mode, url=url, wait_for_completion=True)
75117
logger.debug(f"Result for '{output_mode}', '{mode}', " f"'{input_file}: {whisper_result}")
76118

77119
exp_basename = f"{Path(input_file).stem}.{mode}.{output_mode}.txt"
@@ -83,6 +125,12 @@ def test_whisper_v2_url_in_post(client_v2, data_dir, output_mode, mode, url, inp
83125
verify_usage(usage_before, usage_after, page_count, mode)
84126

85127

128+
def assert_error_message(whisper_result):
129+
assert isinstance(whisper_result, dict)
130+
assert whisper_result["status"] == "error"
131+
assert "error" in whisper_result["message"]
132+
133+
86134
def assert_extracted_text(file_path, whisper_result, mode, output_mode):
87135
with open(file_path, encoding="utf-8") as f:
88136
exp = f.read()
@@ -91,34 +139,45 @@ def assert_extracted_text(file_path, whisper_result, mode, output_mode):
91139
assert whisper_result["status_code"] == 200
92140

93141
# For OCR based processing
94-
threshold = 0.97
142+
threshold = 0.94
95143

96144
# For text based processing
97145
if mode == "native_text" and output_mode == "text":
98146
threshold = 0.99
147+
elif mode == "low_cost":
148+
threshold = 0.90
99149
extracted_text = whisper_result["extraction"]["result_text"]
100150
similarity = SequenceMatcher(None, extracted_text, exp).ratio()
101151

102152
if similarity < threshold:
103153
diff = "\n".join(
104-
unified_diff(exp.splitlines(), extracted_text.splitlines(), fromfile="Expected", tofile="Extracted")
154+
unified_diff(
155+
exp.splitlines(),
156+
extracted_text.splitlines(),
157+
fromfile="Expected",
158+
tofile="Extracted",
159+
)
105160
)
106-
pytest.fail(f"Texts are not similar enough: {similarity * 100:.2f}% similarity. Diff:\n{diff}")
161+
pytest.fail(f"Diff:\n{diff}.\n Texts are not similar enough: {similarity * 100:.2f}% similarity. ")
107162

108163

109-
def verify_usage(before_extract, after_extract, page_count, mode='form'):
110-
all_modes = ['form', 'high_quality', 'low_cost', 'native_text']
164+
def verify_usage(before_extract, after_extract, page_count, mode="form"):
165+
all_modes = ["form", "high_quality", "low_cost", "native_text"]
111166
all_modes.remove(mode)
112-
assert (after_extract['today_page_count'] == before_extract['today_page_count'] + page_count), \
113-
"today_page_count calculation is wrong"
114-
assert (after_extract['current_page_count'] == before_extract['current_page_count'] + page_count), \
115-
"current_page_count calculation is wrong"
116-
if after_extract['overage_page_count'] > 0:
117-
assert (after_extract['overage_page_count'] == before_extract['overage_page_count'] + page_count), \
118-
"overage_page_count calculation is wrong"
119-
assert (after_extract[f'current_page_count_{mode}'] == before_extract[f'current_page_count_{mode}'] + page_count), \
120-
f"{mode} mode calculation is wrong"
167+
assert (
168+
after_extract["today_page_count"] == before_extract["today_page_count"] + page_count
169+
), "today_page_count calculation is wrong"
170+
assert (
171+
after_extract["current_page_count"] == before_extract["current_page_count"] + page_count
172+
), "current_page_count calculation is wrong"
173+
if after_extract["overage_page_count"] > 0:
174+
assert (
175+
after_extract["overage_page_count"] == before_extract["overage_page_count"] + page_count
176+
), "overage_page_count calculation is wrong"
177+
assert (
178+
after_extract[f"current_page_count_{mode}"] == before_extract[f"current_page_count_{mode}"] + page_count
179+
), f"{mode} mode calculation is wrong"
121180
for i in range(len(all_modes)):
122-
assert (after_extract[f'current_page_count_{all_modes[i]}'] ==
123-
before_extract[f'current_page_count_{all_modes[i]}']), \
124-
f"{all_modes[i]} mode calculation is wrong"
181+
assert (
182+
after_extract[f"current_page_count_{all_modes[i]}"] == before_extract[f"current_page_count_{all_modes[i]}"]
183+
), f"{all_modes[i]} mode calculation is wrong"

tests/test_data/test.json

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"test": "HelloWorld"}

0 commit comments

Comments
 (0)