import os
import json
from dotenv import load_dotenv
from attr import field, define
from ..io import DocumentReader
from cloudpathlib import AnyPath, S3Client, GSClient
from ..parsers import TesseractTextOCR, AwsTextOCR, GcpTextOCR

__all__ = ['TextPipeline']


@define
class TextPipeline:
    source = field()
    destination = field()
    parser_type = field()  # one of: 'aws', 'gcp', 'tesseract'
    credentials = field(default=None)  # e.g. {'aws': <path to .env file>, 'gcp': <path to service-account json>}

    def __attrs_post_init__(self):
        # Resolve source and destination into AnyPath objects, attaching a
        # cloud storage client when the path points at GCS or S3.
        self.source = AnyPath(
            self.source, client=self._get_client(self.source))
        self.destination = AnyPath(
            self.destination, client=self._get_client(self.destination))

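    # process_data() walks a source directory (or handles a single source file),
    # runs the configured OCR parser on each document, and writes the JSON
    # result to the destination; directory runs name each output file
    # `<original name>_<parser_type>.json`.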
    def process_data(self):
        if self.source.is_dir():
            for file in self.source.iterdir():
                try:  # TODO: replace this broad try/except with proper file-type validation
                    result = self._process_file(file)
                    file_name = '.'.join(file.name.split(".")[:-1])  # drop the original extension
                    file_name = f"{file_name}_{self.parser_type}.json"
                    save_path = self.destination.joinpath(file_name)
                    save_path.write_text(json.dumps(result))
                except Exception as ex:
                    print(ex)
                    continue

        else:
            result = self._process_file(self.source)
            self.destination.write_text(json.dumps(result))

    def _process_file(self, file):
        reader = self._get_reader(str(file))
        if self.parser_type == 'aws':
            parser = AwsTextOCR(reader, env_file=self.credentials['aws'])

        elif self.parser_type == 'gcp':
            parser = GcpTextOCR(reader, env_file=self.credentials['gcp'])

        elif self.parser_type == 'tesseract':
            parser = TesseractTextOCR(reader)

        else:
            raise ValueError(
                "Parser type not supported; currently only 'aws', 'gcp' and 'tesseract' are supported")
        return parser.parse()

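    # The helpers below choose the reader, storage client and storage type from
    # the path prefix ('gs://', 's3://', or a local path) and from the
    # credentials mapping passed to the pipeline.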
    def _get_reader(self, file):
        storage_type = self._get_storage_type(file)
        if storage_type == 'gs':
            reader = DocumentReader(file, self.credentials['gcp'])
        elif storage_type == 's3':
            reader = DocumentReader(file, self.credentials['aws'])
        else:
            reader = DocumentReader(file)
        return reader

    def _get_client(self, file):
        credentials = self.credentials or {}
        storage_type = self._get_storage_type(file)
        if storage_type == "gs" and credentials.get('gcp'):
            client = GSClient(application_credentials=credentials['gcp'])

        elif storage_type == 's3' and credentials.get('aws'):
            # The 'aws' entry points at a .env file holding the access keys.
            load_dotenv(credentials['aws'])
            client = S3Client(aws_access_key_id=os.getenv(
                'aws_access_key_id'), aws_secret_access_key=os.getenv('aws_secret_access_key'))
        else:
            client = None
        return client

    def _get_storage_type(self, path):
        if path.startswith("gs://"):
            storage_type = 'gs'
        elif path.startswith("s3://"):
            storage_type = 's3'
        else:
            storage_type = 'local'
        return storage_type
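

# A minimal usage sketch, not part of this change: it assumes the package is
# installed so TextPipeline (as defined above) is importable, and the bucket
# name, paths and credential files below are placeholders, not real values.
#
#     pipeline = TextPipeline(
#         source="s3://some-bucket/documents/",
#         destination="output/",
#         parser_type="aws",
#         credentials={"aws": ".env", "gcp": "gcp-credentials.json"},
#     )
#     pipeline.process_data()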