Skip to content

Commit 8ebc024

Browse files
committed
adds core text pipeline
1 parent 70166d2 commit 8ebc024

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

ocrpy/pipelines/text_pipeline.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import os
2+
import json
3+
from dotenv import load_dotenv
4+
from attr import field, define
5+
from ..io import DocumentReader
6+
from cloudpathlib import AnyPath, S3Client, GSClient
7+
from ..parsers import TesseractTextOCR, AwsTextOCR, GcpTextOCR
8+
9+
# Names exported by a star-import of this module.
__all__ = ['TextPipeline']
10+
11+
12+
@define
class TextPipeline:
    """Run a text-OCR parser over one file or a directory and save JSON.

    Attributes:
        source: Input file or directory. A plain string (local path,
            ``gs://...`` or ``s3://...``); replaced by an ``AnyPath`` object
            once the instance is initialised.
        destination: Output location, handled the same way as ``source``.
            For a directory source it is used as the output directory; for a
            single-file source the JSON is written directly to it.
        parser_type: One of ``'aws'``, ``'gcp'`` or ``'tesseract'``.
        credentials: Optional mapping with the keys ``'aws'`` and/or
            ``'gcp'``. The ``'aws'`` value is loaded with ``load_dotenv``;
            the ``'gcp'`` value is passed as ``application_credentials``.
    """

    source = field()
    destination = field()
    parser_type = field()  # aws, gcp, tesseract
    credentials = field(default=None)  # {'aws': , 'gcp': }

    def __attrs_post_init__(self):
        """Convert ``source`` and ``destination`` into path objects."""
        self.source = self._to_path(self.source)
        self.destination = self._to_path(self.destination)

    def process_data(self):
        """Parse the source and write the result(s) as JSON.

        Directory source: every entry is parsed on a best-effort basis and
        saved as ``<name-without-extension>_<parser_type>.json`` under
        ``destination``; a failing entry is reported with ``print`` and
        skipped. Single-file source: the result is written to ``destination``
        and errors propagate to the caller.
        """
        if not self.source.is_dir():
            result = self._process_file(self.source)
            self.destination.write_text(json.dumps(result))
            return

        for file in self.source.iterdir():
            # No file-type validation exists yet, so unsupported entries are
            # expected to fail here; keep going with the remaining files.
            try:
                result = self._process_file(file)
                save_path = self.destination.joinpath(
                    self._output_name(file.name))
                save_path.write_text(json.dumps(result))
            except Exception as ex:
                print(ex)
                continue

    def _output_name(self, file_name):
        """Return the JSON file name used for the input called ``file_name``.

        The last extension is dropped. Names without a usable extension
        (``README``, ``.hidden``) are kept whole so they do not all collapse
        to the same ``_<parser_type>.json`` output.
        """
        base, _, _ = file_name.rpartition(".")
        if not base:
            base = file_name
        return f"{base}_{self.parser_type}.json"

    def _process_file(self, file):
        """Parse a single file with the configured parser.

        Raises:
            ValueError: if ``parser_type`` is not a supported value.
        """
        if self.parser_type not in ('aws', 'gcp', 'tesseract'):
            raise ValueError(
                "Parser type not supported, currently only 'aws', 'gcp' and 'tesseract' are supported")

        # str() is the public way to get the path/URI text (was ``file._str``).
        reader = self._get_reader(str(file))

        # NOTE(review): when no credentials entry exists, ``env_file=None`` is
        # forwarded to the cloud parsers - confirm they accept that.
        if self.parser_type == 'aws':
            parser = AwsTextOCR(reader, env_file=self._credential('aws'))
        elif self.parser_type == 'gcp':
            parser = GcpTextOCR(reader, env_file=self._credential('gcp'))
        else:
            parser = TesseractTextOCR(reader)

        # NOTE(review): ``parse`` is read as an attribute, not called -
        # presumably a property on the parser classes; confirm.
        return parser.parse

    def _get_reader(self, file):
        """Build a ``DocumentReader`` for ``file`` (a path/URI string).

        Cloud locations receive the matching credentials entry when one is
        configured; otherwise the reader is created from the path alone.
        """
        storage_type = self._get_storage_type(file)
        if storage_type == 'gs':
            # The key is 'gcp' everywhere else in this class (was 'gs').
            credential = self._credential('gcp')
        elif storage_type == 's3':
            credential = self._credential('aws')
        else:
            credential = None

        if credential is None:
            return DocumentReader(file)
        return DocumentReader(file, credential)

    def _get_client(self, file):
        """Return a cloud storage client for ``file``, or ``None``.

        ``None`` is returned for local paths and whenever no credentials are
        configured for the detected storage provider.
        """
        storage_type = self._get_storage_type(file)

        if storage_type == 'gs':
            gcp_credentials = self._credential('gcp')
            if gcp_credentials:
                return GSClient(application_credentials=gcp_credentials)

        elif storage_type == 's3':
            aws_env_file = self._credential('aws')
            if aws_env_file:
                # The AWS entry is a dotenv file; the keys are read from the
                # environment after loading it.
                load_dotenv(aws_env_file)
                return S3Client(
                    aws_access_key_id=os.getenv('aws_access_key_id'),
                    aws_secret_access_key=os.getenv('aws_secret_access_key'))

        return None

    def _get_storage_type(self, path):
        """Classify ``path`` as ``'gs'``, ``'s3'`` or ``'local'`` by prefix."""
        location = str(path)  # also accept path objects, not only strings
        if location.startswith("gs://"):
            return 'gs'
        if location.startswith("s3://"):
            return 's3'
        return 'local'

    def _credential(self, provider):
        """Return the credentials entry for ``provider`` or ``None``.

        Tolerates ``credentials`` being ``None`` (the default) or lacking the
        requested key, instead of failing on the lookup.
        """
        if not self.credentials:
            return None
        return self.credentials.get(provider)

    def _to_path(self, location):
        """Turn ``location`` into an ``AnyPath``, attaching a client if any."""
        client = self._get_client(location)
        if client is None:
            # Avoid forwarding an unused ``client=None`` keyword, which would
            # otherwise reach the local-path constructor for local locations.
            return AnyPath(location)
        return AnyPath(location, client=client)

0 commit comments

Comments
 (0)