1
1
import logging
2
2
import os
3
- import unittest
3
+ from difflib import SequenceMatcher , unified_diff
4
4
from pathlib import Path
5
5
6
6
import pytest
7
- import requests
8
-
9
- from unstract .llmwhisperer import LLMWhispererClient
10
7
11
8
logger = logging .getLogger (__name__ )
12
9
@@ -23,9 +20,7 @@ def test_get_usage_info(client):
23
20
"subscription_plan" ,
24
21
"today_page_count" ,
25
22
]
26
- assert set (usage_info .keys ()) == set (
27
- expected_keys
28
- ), f"usage_info { usage_info } does not contain the expected keys"
23
+ assert set (usage_info .keys ()) == set (expected_keys ), f"usage_info { usage_info } does not contain the expected keys"
29
24
30
25
31
26
@pytest .mark .parametrize (
@@ -41,85 +36,38 @@ def test_get_usage_info(client):
41
36
)
42
37
def test_whisper (client , data_dir , processing_mode , output_mode , input_file ):
43
38
file_path = os .path .join (data_dir , input_file )
44
- response = client .whisper (
39
+ whisper_result = client .whisper (
45
40
processing_mode = processing_mode ,
46
41
output_mode = output_mode ,
47
42
file_path = file_path ,
48
43
timeout = 200 ,
49
44
)
50
- logger .debug (response )
45
+ logger .debug (whisper_result )
51
46
52
47
exp_basename = f"{ Path (input_file ).stem } .{ processing_mode } .{ output_mode } .txt"
53
48
exp_file = os .path .join (data_dir , "expected" , exp_basename )
54
- with open (exp_file , encoding = "utf-8" ) as f :
55
- exp = f .read ()
56
49
57
- assert isinstance (response , dict )
58
- assert response ["status_code" ] == 200
59
- assert response ["extracted_text" ] == exp
50
+ assert_extracted_text (exp_file , whisper_result , processing_mode , output_mode )
60
51
61
52
62
- # TODO: Review and port to pytest based tests
63
- class TestLLMWhispererClient (unittest .TestCase ):
64
- @unittest .skip ("Skipping test_whisper" )
65
- def test_whisper (self ):
66
- client = LLMWhispererClient ()
67
- # response = client.whisper(
68
- # url="https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
69
- # )
70
- response = client .whisper (
71
- file_path = "test_data/restaurant_invoice_photo.pdf" ,
72
- timeout = 200 ,
73
- store_metadata_for_highlighting = True ,
74
- )
75
- print (response )
76
- # self.assertIsInstance(response, dict)
53
+ def assert_extracted_text (file_path , whisper_result , mode , output_mode ):
54
+ with open (file_path , encoding = "utf-8" ) as f :
55
+ exp = f .read ()
77
56
78
- # @unittest.skip("Skipping test_whisper")
79
- def test_whisper_stream (self ):
80
- client = LLMWhispererClient ()
81
- download_url = (
82
- "https://storage.googleapis.com/pandora-static/samples/bill.jpg.pdf"
83
- )
84
- # Create a stream of download_url and pass it to whisper
85
- response_download = requests .get (download_url , stream = True )
86
- response_download .raise_for_status ()
87
- response = client .whisper (
88
- stream = response_download .iter_content (chunk_size = 1024 ),
89
- timeout = 200 ,
90
- store_metadata_for_highlighting = True ,
91
- )
92
- print (response )
93
- # self.assertIsInstance(response, dict)
57
+ assert isinstance (whisper_result , dict )
58
+ assert whisper_result ["status_code" ] == 200
94
59
95
- @unittest .skip ("Skipping test_whisper_status" )
96
- def test_whisper_status (self ):
97
- client = LLMWhispererClient ()
98
- response = client .whisper_status (
99
- whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
100
- )
101
- logger .info (response )
102
- self .assertIsInstance (response , dict )
60
+ # For OCR based processing
61
+ threshold = 0.97
103
62
104
- @unittest .skip ("Skipping test_whisper_retrieve" )
105
- def test_whisper_retrieve (self ):
106
- client = LLMWhispererClient ()
107
- response = client .whisper_retrieve (
108
- whisper_hash = "7cfa5cbb|5f1d285a7cf18d203de7af1a1abb0a3a"
109
- )
110
- logger .info (response )
111
- self .assertIsInstance (response , dict )
63
+ # For text based processing
64
+ if mode == "native_text" and output_mode == "text" :
65
+ threshold = 0.99
66
+ extracted_text = whisper_result ["extracted_text" ]
67
+ similarity = SequenceMatcher (None , extracted_text , exp ).ratio ()
112
68
113
- @unittest .skip ("Skipping test_whisper_highlight_data" )
114
- def test_whisper_highlight_data (self ):
115
- client = LLMWhispererClient ()
116
- response = client .highlight_data (
117
- whisper_hash = "9924d865|5f1d285a7cf18d203de7af1a1abb0a3a" ,
118
- search_text = "Indiranagar" ,
69
+ if similarity < threshold :
70
+ diff = "\n " .join (
71
+ unified_diff (exp .splitlines (), extracted_text .splitlines (), fromfile = "Expected" , tofile = "Extracted" )
119
72
)
120
- logger .info (response )
121
- self .assertIsInstance (response , dict )
122
-
123
-
124
- if __name__ == "__main__" :
125
- unittest .main ()
73
+ pytest .fail (f"Texts are not similar enough: { similarity * 100 :.2f} % similarity. Diff:\n { diff } " )
0 commit comments