Skip to content

Commit cd05f54

Browse files
committed
Fixed type matching issues flagged during build
1 parent b74cdcf commit cd05f54

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

cogstack2.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class CogStack():
2424

2525
def __init__(self, hosts: List[str]):
2626
self.hosts = hosts
27-
self.elastic = None
27+
self.elastic: elasticsearch.Elasticsearch
2828

2929
@classmethod
3030
def with_basic_auth(cls,
@@ -138,11 +138,13 @@ def use_api_key_auth(self, api_key: Optional[Dict] = None) -> 'CogStack':
138138
-------
139139
CogStack: An instance of the CogStack class.
140140
"""
141+
has_encoded_value = False
142+
api_id_value:str
143+
api_key_value:str
144+
141145
if not api_key:
142146
api_key = {"encoded": input("Encoded API key: ")}
143-
has_encoded_value = False
144-
api_id_value, api_key_value = None, None
145-
if api_key is not None:
147+
else:
146148
if isinstance(api_key, str):
147149
# If api_key is a string, it is assumed to be the encoded API key
148150
encoded = api_key
@@ -161,18 +163,18 @@ def use_api_key_auth(self, api_key: Optional[Dict] = None) -> 'CogStack':
161163
has_encoded_value = encoded is not None and encoded != ''
162164

163165
if(not has_encoded_value):
164-
api_id_value = api_key["id"] \
166+
api_id_value = str(api_key["id"] \
165167
if "id" in api_key.keys() and api_key["id"] != '' \
166-
else input("API Id: ")
167-
api_key_value = api_key["api_key"] \
168+
else input("API Id: "))
169+
api_key_value = str(api_key["api_key"] \
168170
if "api_key" in api_key.keys() and api_key["api_key"] != '' \
169-
else getpass.getpass("API Key: ")
171+
else getpass.getpass("API Key: "))
170172

171173
return self.__connect(api_key=encoded if has_encoded_value else (api_id_value, api_key_value))
172174

173175
def __connect(self,
174176
basic_auth : Optional[tuple[str,str]] = None,
175-
api_key: Optional[Union[str, tuple[str, str], None]] = None) -> 'CogStack':
177+
api_key: Optional[Union[str, tuple[str, str]]] = None) -> 'CogStack':
176178
""" Connect to Elasticsearch using the provided credentials.
177179
Parameters
178180
----------
@@ -189,10 +191,10 @@ def __connect(self,
189191
Exception: If the connection to Elasticsearch fails.
190192
"""
191193
self.elastic = elasticsearch.Elasticsearch(hosts=self.hosts,
192-
api_key=api_key,
193-
basic_auth=basic_auth,
194-
verify_certs=False,
195-
request_timeout=self.ES_TIMEOUT)
194+
api_key=api_key,
195+
basic_auth=basic_auth,
196+
verify_certs=False,
197+
request_timeout=self.ES_TIMEOUT)
196198
if not self.elastic.ping():
197199
raise ConnectionError("CogStack connection failed. " \
198200
"Please check your host list and credentials and try again.")
@@ -240,6 +242,8 @@ def get_index_fields(self, index: Union[str, Sequence[str]]):
240242
If the operation fails for any reason.
241243
"""
242244
try:
245+
if len(index) == 0:
246+
raise ValueError('Provide at least one index or index alias name')
243247
all_mappings = self.elastic.indices\
244248
.get_mapping(index=index, allow_no_indices=False).body
245249
columns= ['Field', 'Type']
@@ -282,8 +286,10 @@ def count_search_results(self, index: Union[str, Sequence[str]], query: dict):
282286
.. code-block:: json
283287
{"match": {"title": "python"}}}
284288
"""
289+
if len(index) == 0:
290+
raise ValueError('Provide at least one index or index alias name')
285291
query = self.__extract_query(query=query)
286-
count = self.elastic.count(index=index, query=query)['count']
292+
count = self.elastic.count(index=index, query=query, allow_no_indices=False)['count']
287293
return f"Number of documents: {format(count, ',')}"
288294

289295
def read_data_with_scan(self,
@@ -340,20 +346,23 @@ def read_data_with_scan(self,
340346
If the search fails or cancelled by the user.
341347
"""
342348
try:
349+
if len(index) == 0:
350+
raise ValueError('Provide at least one index or index alias name')
343351
self.__validate_size(size=size)
344352
if "query" not in query.keys():
345353
temp_query = query.copy()
346354
query.clear()
347355
query["query"] = temp_query
348-
pr_bar = None
356+
pr_bar: tqdm.tqdm = None
349357

350358
scan_results = es_helpers.scan(self.elastic,
351359
index=index,
352360
query=query,
353361
size=size,
354362
request_timeout=request_timeout,
355363
source=False,
356-
fields = include_fields)
364+
fields = include_fields,
365+
allow_no_indices=False,)
357366
all_mapped_results = []
358367
results = self.elastic.count(index=index, query=query["query"])
359368
pr_bar = tqdm.tqdm(scan_results, total=results["count"],
@@ -443,13 +452,15 @@ def read_data_with_scroll(self,
443452
value of `size` parameter.
444453
"""
445454
try:
455+
if len(index) == 0:
456+
raise ValueError('Provide at least one index or index alias name')
446457
self.__validate_size(size=size)
447458
query = self.__extract_query(query=query)
448459
result_count = size
449460
all_mapped_results =[]
450461
search_result=None
451-
include_fields_map: Sequence[Mapping[str, Any]] = include_fields \
452-
if include_fields is not None else None
462+
include_fields_map: Union[Sequence[Mapping[str, Any]], None] = \
463+
[{"field": field} for field in include_fields] if include_fields is not None else None
453464

454465
pr_bar = tqdm.tqdm(desc="CogStack retrieved...",
455466
disable=not show_progress, colour='green')
@@ -462,6 +473,7 @@ def read_data_with_scroll(self,
462473
source=False,
463474
scroll="10m",
464475
timeout=f"{request_timeout}s",
476+
allow_no_indices=False,
465477
rest_total_hits_as_int=True)
466478

467479
pr_bar.total = search_result.body['hits']['total']
@@ -470,6 +482,8 @@ def read_data_with_scroll(self,
470482
search_scroll_id = search_result.body['_scroll_id']
471483
all_mapped_results.extend(self.__map_search_results(hits=hits))
472484
pr_bar.update(len(hits))
485+
if search_result["_shards"]["failed"] > 0:
486+
raise LookupError(search_result["_shards"]["failures"])
473487

474488
while search_scroll_id and result_count == size:
475489
# Perform ES scroll request
@@ -559,14 +573,15 @@ def read_data_with_sorting(self,
559573
which can be used as a function parameter to continue the search.
560574
"""
561575
try:
576+
if len(index) == 0:
577+
raise ValueError('Provide at least one index or index alias name')
562578
result_count = size
563579
all_mapped_results =[]
564580
if sort is None:
565581
sort = {'id': 'asc'}
566582
search_after_value = search_after
567-
include_fields_map: Sequence[Mapping[str, Any]] = include_fields \
568-
if include_fields is not None \
569-
else None
583+
include_fields_map: Union[Sequence[Mapping[str, Any]], None] = \
584+
[{"field": field} for field in include_fields] if include_fields is not None else None
570585

571586
self.__validate_size(size=size)
572587
query = self.__extract_query(query=query)
@@ -591,6 +606,7 @@ def read_data_with_sorting(self,
591606
search_after=search_after_value,
592607
timeout=f"{request_timeout}s",
593608
track_scores=True,
609+
track_total_hits=True,
594610
allow_no_indices=False,
595611
rest_total_hits_as_int=True)
596612
hits = search_result['hits']['hits']
@@ -599,6 +615,8 @@ def read_data_with_sorting(self,
599615
pr_bar.update(result_count)
600616
search_after_value = hits[-1]['sort']
601617
pr_bar.total = pr_bar.total if pr_bar.total else search_result.body['hits']['total']
618+
if search_result["_shards"]["failed"] > 0:
619+
raise LookupError(search_result["_shards"]["failures"])
602620
except BaseException as err:
603621
if isinstance(err, KeyboardInterrupt):
604622
pr_bar.bar_format = "%s{l_bar}%s{bar}%s{r_bar}" % ("\033[0;33m",
@@ -619,12 +637,12 @@ def read_data_with_sorting(self,
619637

620638
def __extract_query(self, query: dict):
621639
if "query" in query.keys():
622-
query = query['query']
640+
return query['query']
623641
return query
624642

625643
def __validate_size(self, size):
626644
if size > 10000:
627-
raise ValueError('Size must not be greater then 10000')
645+
raise ValueError('Size must not be greater than 10000')
628646

629647
def __map_search_results(self, hits: Iterable):
630648
hit: dict

0 commit comments

Comments
 (0)