Skip to content

Commit 27a7b96

Browse files
authored
Merge pull request #7 from maxent-ai/pipeline
Pipeline
2 parents bdd094e + 0a5646d commit 27a7b96

12 files changed

+216
-99
lines changed

ocrpy/io/reader.py

+46-36
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,59 @@
1-
import io
21
import os
3-
import pdf2image
2+
from io import BytesIO
3+
from ..utils import LOGGER
44
from dotenv import load_dotenv
55
from attr import define, field
6+
from pdf2image import convert_from_bytes
7+
from ..utils import FileTypeNotSupported
8+
from typing import Union, Generator, ByteString
9+
from ..utils import guess_extension, guess_storage
610
from cloudpathlib import S3Client, GSClient, AnyPath
711

812
__all__ = ['DocumentReader']
913

14+
# TO-DO: Add logging and improve the error handling
15+
1016

1117
@define
1218
class DocumentReader:
1319
"""
14-
Read an image file from a given location and returns a byte array.
20+
Reads an image or a pdf file from a local or remote location.
21+
Note: Currently supports Google Storage and Amazon S3 Remote Files.
22+
23+
Attributes
24+
----------
25+
file : str
26+
The path to the file to be read.
27+
credentials : str
28+
The path to the credentials file.
29+
Note:
30+
If the Remote storage is AWS S3, the credentials file must be in the .env format.
31+
If the Remote storage is Google Storage, the credentials file must be in the .json format.
1532
"""
16-
file = field()
17-
credentials = field(default=None)
33+
file: str = field()
34+
credentials: str = field(default=None)
35+
storage_type = field(default=None, init=False)
1836

19-
def read(self):
20-
file_type = self.get_file_type()
21-
if file_type == 'image':
22-
return self._read_image(self.file)
23-
elif file_type == 'pdf':
24-
return self._read_pdf(self.file)
25-
else:
26-
raise ValueError("File type not supported")
37+
def __attrs_post_init__(self):
38+
self.storage_type = guess_storage(self.file)
2739

28-
def get_file_type(self):
29-
if self.file.endswith(".png") or self.file.endswith(".jpg"):
30-
file_type = "image"
31-
elif self.file.endswith(".pdf"):
32-
file_type = "pdf"
33-
else:
34-
file_type = "unknown"
35-
return file_type
40+
def read(self) -> Union[Generator, ByteString]:
41+
"""
42+
Reads the file from a local or remote location and
43+
returns the data in byte-string for an image or as a
44+
generator of byte-strings for a pdf.
3645
37-
def get_storage_type(self):
38-
storage_type = None
39-
if self.file.startswith("gs://"):
40-
storage_type = 'gs'
41-
elif self.file.startswith("s3://"):
42-
storage_type = 's3'
43-
else:
44-
storage_type = 'local'
45-
return storage_type
46+
Returns
47+
-------
48+
data : Union[bytes, List[bytes]]
49+
The data in byte-string for an image or as a
50+
generator of byte-strings for a pdf.
51+
"""
52+
53+
file_type = guess_extension(self.file)
54+
reader_methods = {'IMAGE': self._read_image, 'PDF': self._read_pdf}
55+
return reader_methods[file_type](self.file) if file_type in reader_methods else FileTypeNotSupported(
56+
f"""We failed to understand the file type of {self.file}. The supported file-types are .png, .jpg or .pdf files. Please check the file type and try again.""")
4657

4758
def _read_image(self, file):
4859
return self._read(file)
@@ -58,22 +69,21 @@ def _read(self, file):
5869
return file_data.read_bytes()
5970

6071
def _get_client(self, file):
61-
storage_type = self.get_storage_type()
62-
if storage_type == "gs" and self.credentials:
72+
storage_type = self.storage_type
73+
if storage_type == "GS" and self.credentials:
6374
client = GSClient(application_credentials=self.credentials)
6475

65-
elif storage_type == 's3' and self.credentials:
76+
elif storage_type == 'S3' and self.credentials:
6677
load_dotenv(self.credentials)
6778
client = S3Client(aws_access_key_id=os.getenv(
6879
'aws_access_key_id'), aws_secret_access_key=os.getenv('aws_secret_access_key'))
6980
else:
7081
client = None
71-
7282
return client
7383

7484
def _bytes_to_images(self, data):
75-
images = pdf2image.convert_from_bytes(data)
85+
images = convert_from_bytes(data)
7686
for image in images:
77-
buf = io.BytesIO()
87+
buf = BytesIO()
7888
image.save(buf, format='PNG')
7989
yield buf.getvalue()

ocrpy/parsers/core.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ class AbstractTextOCR:
3333
Abstract class for Text OCR backends.
3434
"""
3535
reader: Any = field()
36+
credentials: str = field()
3637

37-
@abc.abstractproperty
38+
@abc.abstractmethod
3839
def parse(self):
3940
return NotImplemented
4041

ocrpy/parsers/text/aws_text.py

+46-27
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import os
2-
import boto3
32
import time
4-
from cloudpathlib import AnyPath
3+
import boto3
54
from dotenv import load_dotenv
65
from attr import define, field
7-
from typing import List, Dict, Any
8-
from ...utils.errors import NotSupportedError
6+
from cloudpathlib import AnyPath
7+
from ...utils.exceptions import AttributeNotSupported
8+
from typing import List, Dict, Any, Union, Generator, ByteString
99
from ..core import AbstractTextOCR, AbstractLineSegmenter, AbstractBlockSegmenter
1010

1111
__all__ = ['AwsTextOCR']
1212

13+
# TO-DO: Add logging and improve the error handling
14+
1315

1416
def aws_region_extractor(block):
1517
x1, x2 = block['Geometry']['BoundingBox']['Left'], block['Geometry']['BoundingBox']['Left'] + \
@@ -28,7 +30,7 @@ def aws_token_formator(token):
2830
return token
2931

3032

31-
def is_job_complete(client, job_id):
33+
def _job_status(client, job_id):
3234
time.sleep(1)
3335
response = client.get_document_text_detection(JobId=job_id)
3436
status = response["JobStatus"]
@@ -40,7 +42,7 @@ def is_job_complete(client, job_id):
4042
return status
4143

4244

43-
def get_job_results(client, job_id):
45+
def _fetch_job_result(client, job_id):
4446
pages = []
4547
response = client.get_document_text_detection(JobId=job_id)
4648
pages.append(response)
@@ -101,27 +103,45 @@ class AwsBlockSegmenter(AbstractBlockSegmenter):
101103

102104
@property
103105
def blocks(self):
104-
raise NotSupportedError(
106+
raise AttributeNotSupported(
105107
"AWS Backend does not support block segmentation yet.")
106108

107109

108110
@define
109111
class AwsTextOCR(AbstractTextOCR):
110-
env_file = field(default=None)
111-
textract = field(repr=False, init=False)
112-
document = field(default=None, repr=False)
112+
"""
113+
AWS Textract OCR Engine
114+
115+
Attributes
116+
----------
117+
reader: Any
118+
Reader object that can be used to read the document.
119+
credentials : str
120+
Path to credentials file.
121+
Note: The credentials file must be in .env format.
122+
"""
123+
_client: boto3.client = field(repr=False, init=False)
124+
_document: Union[Generator, ByteString] = field(
125+
default=None, repr=False, init=False)
113126

114127
def __attrs_post_init__(self):
115-
if self.env_file:
116-
load_dotenv(self.env_file)
128+
if self.credentials:
129+
load_dotenv(self.credentials)
117130
region = os.getenv('region_name')
118131
access_key = os.getenv('aws_access_key_id')
119132
secret_key = os.getenv('aws_secret_access_key')
120-
self.textract = boto3.client('textract', region_name=region,
121-
aws_access_key_id=access_key, aws_secret_access_key=secret_key)
122-
123-
@property
124-
def parse(self):
133+
self._client = boto3.client('textract', region_name=region,
134+
aws_access_key_id=access_key, aws_secret_access_key=secret_key)
135+
136+
def parse(self) -> Dict[int, Dict]:
137+
"""
138+
Parses the document and returns the ocr data as a dictionary of pages along with additional metadata.
139+
140+
Returns
141+
-------
142+
parsed_data : dict
143+
Dictionary of pages.
144+
"""
125145
return self._process_data()
126146

127147
def _process_data(self):
@@ -130,34 +150,33 @@ def _process_data(self):
130150
if not isinstance(ocr, list):
131151
ocr = [ocr]
132152
for index, page in enumerate(ocr):
133-
print("Processing page {}".format(index))
134153
data = dict(text=self._get_text(page), lines=self._get_lines(
135154
page), blocks=self._get_blocks(page), tokens=self._get_tokens(page))
136155
result[index] = data
137156
return result
138157

139158
def _get_ocr(self):
140-
storage_type = self.reader.get_storage_type()
159+
storage_type = self.reader.storage_type
141160

142161
if storage_type == 's3':
143162
path = AnyPath(self.reader.file)
144163

145-
response = self.textract.start_document_text_detection(DocumentLocation={
164+
response = self._client.start_document_text_detection(DocumentLocation={
146165
'S3Object': {
147166
'Bucket': path.bucket,
148167
'Name': path.key
149168
}})
150169
job_id = response['JobId']
151-
status = is_job_complete(self.textract, job_id)
152-
ocr = get_job_results(self.textract, job_id)
170+
status = _job_status(self.textract, job_id)
171+
ocr = _fetch_job_result(self.textract, job_id)
153172

154173
else:
155-
self.document = self.reader.read()
156-
if isinstance(self.document, bytes):
157-
self.document = [self.document]
174+
self._document = self.reader.read()
175+
if isinstance(self._document, bytes):
176+
self._document = [self._document]
158177
ocr = []
159-
for document in self.document:
160-
result = self.textract.detect_document_text(
178+
for document in self._document:
179+
result = self._client.detect_document_text(
161180
Document={'Bytes': document})
162181
ocr.append(result)
163182
return ocr

ocrpy/parsers/text/gcp_text.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from google.cloud import vision
44
from typing import List, Dict, Any
55
from google.oauth2 import service_account
6-
from ...utils.errors import NotSupportedError
6+
from ...utils.exceptions import AttributeNotSupported
77
from ..core import AbstractTextOCR, AbstractLineSegmenter, AbstractBlockSegmenter
88

99
__all__ = ['GcpTextOCR']
1010

11+
# TO-DO: Add logging and improve the error handling
1112

1213
def gcp_region_extractor(block):
1314
x_points = [v.x for v in block]
@@ -74,35 +75,52 @@ class GCPLineSegmenter(AbstractLineSegmenter):
7475

7576
@property
7677
def lines(self):
77-
NotSupportedError("GCP Backend does not support line segmentation yet.")
78+
AttributeNotSupported("GCP Backend does not support line segmentation yet.")
7879

7980
@define
8081
class GcpTextOCR(AbstractTextOCR):
81-
env_file = field(default=None)
82-
client = field(repr=False, init=False)
83-
document = field(default=None, repr=False)
82+
"""
83+
Google Cloud Vision OCR Engine
84+
85+
Attributes
86+
----------
87+
reader : Any
88+
Reader object that can be used to read the document.
89+
credentials : str
90+
Path to credentials file.
91+
Note: The credentials file must be in .json format.
92+
"""
93+
_client = field(repr=False, init=False)
94+
_document = field(default=None, repr=False, init=False)
8495

8596
def __attrs_post_init__(self):
86-
if self.env_file:
87-
cred = service_account.Credentials.from_service_account_file(self.env_file)
88-
self.client = vision.ImageAnnotatorClient(credentials=cred)
97+
if self.credentials:
98+
cred = service_account.Credentials.from_service_account_file(self.credentials)
99+
self._client = vision.ImageAnnotatorClient(credentials=cred)
89100
else:
90-
self.client = vision.ImageAnnotatorClient()
101+
self._client = vision.ImageAnnotatorClient()
91102

92-
self.document = self.reader.read()
103+
self._document = self.reader.read()
93104

94-
@property
95105
def parse(self):
106+
"""
107+
Parses the document and returns the ocr data as a dictionary of pages along with additional metadata.
108+
109+
Returns
110+
-------
111+
parsed_data : dict
112+
Dictionary of pages.
113+
"""
96114
return self._process_data()
97115

98116
def _process_data(self):
99117
is_image = False
100-
if isinstance(self.document, bytes):
101-
self.document = [self.document]
118+
if isinstance(self._document, bytes):
119+
self._document = [self._document]
102120
is_image = True
103121

104122
result = {}
105-
for index, document in enumerate(self.document):
123+
for index, document in enumerate(self._document):
106124

107125
ocr = self._get_ocr(document)
108126
blocks = self._get_blocks(ocr)
@@ -147,7 +165,7 @@ def _get_text(self, ocr):
147165

148166
def _get_ocr(self, image):
149167
image = vision.types.Image(content=image)
150-
ocr = self.client.document_text_detection(image=image).full_text_annotation
168+
ocr = self._client.document_text_detection(image=image).full_text_annotation
151169
return ocr
152170

153171

0 commit comments

Comments
 (0)