@@ -24,7 +24,7 @@ class CogStack():
24
24
25
25
def __init__ (self , hosts : List [str ]):
26
26
self .hosts = hosts
27
- self .elastic = None
27
+ self .elastic : elasticsearch . Elasticsearch
28
28
29
29
@classmethod
30
30
def with_basic_auth (cls ,
@@ -138,11 +138,13 @@ def use_api_key_auth(self, api_key: Optional[Dict] = None) -> 'CogStack':
138
138
-------
139
139
CogStack: An instance of the CogStack class.
140
140
"""
141
+ has_encoded_value = False
142
+ api_id_value :str
143
+ api_key_value :str
144
+
141
145
if not api_key :
142
146
api_key = {"encoded" : input ("Encoded API key: " )}
143
- has_encoded_value = False
144
- api_id_value , api_key_value = None , None
145
- if api_key is not None :
147
+ else :
146
148
if isinstance (api_key , str ):
147
149
# If api_key is a string, it is assumed to be the encoded API key
148
150
encoded = api_key
@@ -161,18 +163,18 @@ def use_api_key_auth(self, api_key: Optional[Dict] = None) -> 'CogStack':
161
163
has_encoded_value = encoded is not None and encoded != ''
162
164
163
165
if (not has_encoded_value ):
164
- api_id_value = api_key ["id" ] \
166
+ api_id_value = str ( api_key ["id" ] \
165
167
if "id" in api_key .keys () and api_key ["id" ] != '' \
166
- else input ("API Id: " )
167
- api_key_value = api_key ["api_key" ] \
168
+ else input ("API Id: " ))
169
+ api_key_value = str ( api_key ["api_key" ] \
168
170
if "api_key" in api_key .keys () and api_key ["api_key" ] != '' \
169
- else getpass .getpass ("API Key: " )
171
+ else getpass .getpass ("API Key: " ))
170
172
171
173
return self .__connect (api_key = encoded if has_encoded_value else (api_id_value , api_key_value ))
172
174
173
175
def __connect (self ,
174
176
basic_auth : Optional [tuple [str ,str ]] = None ,
175
- api_key : Optional [Union [str , tuple [str , str ], None ]] = None ) -> 'CogStack' :
177
+ api_key : Optional [Union [str , tuple [str , str ]]] = None ) -> 'CogStack' :
176
178
""" Connect to Elasticsearch using the provided credentials.
177
179
Parameters
178
180
----------
@@ -189,10 +191,10 @@ def __connect(self,
189
191
Exception: If the connection to Elasticsearch fails.
190
192
"""
191
193
self .elastic = elasticsearch .Elasticsearch (hosts = self .hosts ,
192
- api_key = api_key ,
193
- basic_auth = basic_auth ,
194
- verify_certs = False ,
195
- request_timeout = self .ES_TIMEOUT )
194
+ api_key = api_key ,
195
+ basic_auth = basic_auth ,
196
+ verify_certs = False ,
197
+ request_timeout = self .ES_TIMEOUT )
196
198
if not self .elastic .ping ():
197
199
raise ConnectionError ("CogStack connection failed. " \
198
200
"Please check your host list and credentials and try again." )
@@ -240,6 +242,8 @@ def get_index_fields(self, index: Union[str, Sequence[str]]):
240
242
If the operation fails for any reason.
241
243
"""
242
244
try :
245
+ if len (index ) == 0 :
246
+ raise ValueError ('Provide at least one index or index alias name' )
243
247
all_mappings = self .elastic .indices \
244
248
.get_mapping (index = index , allow_no_indices = False ).body
245
249
columns = ['Field' , 'Type' ]
@@ -282,8 +286,10 @@ def count_search_results(self, index: Union[str, Sequence[str]], query: dict):
282
286
.. code-block:: json
283
287
{"match": {"title": "python"}}
284
288
"""
289
+ if len (index ) == 0 :
290
+ raise ValueError ('Provide at least one index or index alias name' )
285
291
query = self .__extract_query (query = query )
286
- count = self .elastic .count (index = index , query = query )['count' ]
292
+ count = self .elastic .count (index = index , query = query , allow_no_indices = False )['count' ]
287
293
return f"Number of documents: { format (count , ',' )} "
288
294
289
295
def read_data_with_scan (self ,
@@ -340,20 +346,23 @@ def read_data_with_scan(self,
340
346
If the search fails or cancelled by the user.
341
347
"""
342
348
try :
349
+ if len (index ) == 0 :
350
+ raise ValueError ('Provide at least one index or index alias name' )
343
351
self .__validate_size (size = size )
344
352
if "query" not in query .keys ():
345
353
temp_query = query .copy ()
346
354
query .clear ()
347
355
query ["query" ] = temp_query
348
- pr_bar = None
356
+ pr_bar : tqdm . tqdm = None
349
357
350
358
scan_results = es_helpers .scan (self .elastic ,
351
359
index = index ,
352
360
query = query ,
353
361
size = size ,
354
362
request_timeout = request_timeout ,
355
363
source = False ,
356
- fields = include_fields )
364
+ fields = include_fields ,
365
+ allow_no_indices = False ,)
357
366
all_mapped_results = []
358
367
results = self .elastic .count (index = index , query = query ["query" ])
359
368
pr_bar = tqdm .tqdm (scan_results , total = results ["count" ],
@@ -443,13 +452,15 @@ def read_data_with_scroll(self,
443
452
value of `size` parameter.
444
453
"""
445
454
try :
455
+ if len (index ) == 0 :
456
+ raise ValueError ('Provide at least one index or index alias name' )
446
457
self .__validate_size (size = size )
447
458
query = self .__extract_query (query = query )
448
459
result_count = size
449
460
all_mapped_results = []
450
461
search_result = None
451
- include_fields_map : Sequence [Mapping [str , Any ]] = include_fields \
452
- if include_fields is not None else None
462
+ include_fields_map : Union [ Sequence [Mapping [str , Any ]], None ] = \
463
+ [{ "field" : field } for field in include_fields ] if include_fields is not None else None
453
464
454
465
pr_bar = tqdm .tqdm (desc = "CogStack retrieved..." ,
455
466
disable = not show_progress , colour = 'green' )
@@ -462,6 +473,7 @@ def read_data_with_scroll(self,
462
473
source = False ,
463
474
scroll = "10m" ,
464
475
timeout = f"{ request_timeout } s" ,
476
+ allow_no_indices = False ,
465
477
rest_total_hits_as_int = True )
466
478
467
479
pr_bar .total = search_result .body ['hits' ]['total' ]
@@ -470,6 +482,8 @@ def read_data_with_scroll(self,
470
482
search_scroll_id = search_result .body ['_scroll_id' ]
471
483
all_mapped_results .extend (self .__map_search_results (hits = hits ))
472
484
pr_bar .update (len (hits ))
485
+ if search_result ["_shards" ]["failed" ] > 0 :
486
+ raise LookupError (search_result ["_shards" ]["failures" ])
473
487
474
488
while search_scroll_id and result_count == size :
475
489
# Perform ES scroll request
@@ -559,14 +573,15 @@ def read_data_with_sorting(self,
559
573
which can be used as a function parameter to continue the search.
560
574
"""
561
575
try :
576
+ if len (index ) == 0 :
577
+ raise ValueError ('Provide at least one index or index alias name' )
562
578
result_count = size
563
579
all_mapped_results = []
564
580
if sort is None :
565
581
sort = {'id' : 'asc' }
566
582
search_after_value = search_after
567
- include_fields_map : Sequence [Mapping [str , Any ]] = include_fields \
568
- if include_fields is not None \
569
- else None
583
+ include_fields_map : Union [Sequence [Mapping [str , Any ]], None ] = \
584
+ [{"field" : field } for field in include_fields ] if include_fields is not None else None
570
585
571
586
self .__validate_size (size = size )
572
587
query = self .__extract_query (query = query )
@@ -591,6 +606,7 @@ def read_data_with_sorting(self,
591
606
search_after = search_after_value ,
592
607
timeout = f"{ request_timeout } s" ,
593
608
track_scores = True ,
609
+ track_total_hits = True ,
594
610
allow_no_indices = False ,
595
611
rest_total_hits_as_int = True )
596
612
hits = search_result ['hits' ]['hits' ]
@@ -599,6 +615,8 @@ def read_data_with_sorting(self,
599
615
pr_bar .update (result_count )
600
616
search_after_value = hits [- 1 ]['sort' ]
601
617
pr_bar .total = pr_bar .total if pr_bar .total else search_result .body ['hits' ]['total' ]
618
+ if search_result ["_shards" ]["failed" ] > 0 :
619
+ raise LookupError (search_result ["_shards" ]["failures" ])
602
620
except BaseException as err :
603
621
if isinstance (err , KeyboardInterrupt ):
604
622
pr_bar .bar_format = "%s{l_bar}%s{bar}%s{r_bar}" % ("\033 [0;33m" ,
@@ -619,12 +637,12 @@ def read_data_with_sorting(self,
619
637
620
638
def __extract_query (self , query : dict ):
621
639
if "query" in query .keys ():
622
- query = query ['query' ]
640
+ return query ['query' ]
623
641
return query
624
642
625
643
def __validate_size (self , size ):
626
644
if size > 10000 :
627
- raise ValueError ('Size must not be greater then 10000' )
645
+ raise ValueError ('Size must not be greater than 10000' )
628
646
629
647
def __map_search_results (self , hits : Iterable ):
630
648
hit : dict
0 commit comments