Skip to content

Commit 5705154

Browse files
authored
Update magics for new Neptune Analytics API (#560)
* Support new Analytics API
* Fix %load, better %summary mode messaging
* Update changelog
1 parent bb96dd8 commit 5705154

File tree

3 files changed

+142
-45
lines changed

3 files changed

+142
-45
lines changed

ChangeLog.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ Starting with v1.31.6, this file will contain a record of major features and upd
55
## Upcoming
66
- New Neptune Analytics notebook - Vector Similarity Algorithms ([Link to PR](https://github.com/aws/graph-notebook/pull/555))
77
- Path: 02-Neptune-Analytics > 02-Graph-Algorithms > 06-Vector-Similarity-Algorithms
8-
- Deprecated Python 3.7 support ([Link to PR](https://github.com/aws/graph-notebook/pull/551))
8+
- Updated various Neptune magics for new Analytics API ([Link to PR](https://github.com/aws/graph-notebook/pull/560))
9+
- Added `%graph_notebook_service` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/560))
910
- Added unit abbreviation support to `--max-content-length` ([Link to PR](https://github.com/aws/graph-notebook/pull/553))
11+
- Deprecated Python 3.7 support ([Link to PR](https://github.com/aws/graph-notebook/pull/551))
1012

1113
## Release 4.0.2 (Dec 14, 2023)
1214
- Fixed `neptune_ml_utils` imports in `03-Neptune-ML` samples ([Link to PR](https://github.com/aws/graph-notebook/pull/546))

src/graph_notebook/magics/graph_magic.py

Lines changed: 105 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@
4242
from graph_notebook.decorators.decorators import display_exceptions, magic_variables, neptune_db_only
4343
from graph_notebook.magics.ml import neptune_ml_magic_handler, generate_neptune_ml_parser
4444
from graph_notebook.magics.streams import StreamViewer
45-
from graph_notebook.neptune.client import ClientBuilder, Client,PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
45+
from graph_notebook.neptune.client import ClientBuilder, Client, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
4646
LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
4747
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
4848
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
4949
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \
50-
STATISTICS_LANGUAGE_INPUTS, STATISTICS_MODES, SUMMARY_MODES, \
51-
SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT
50+
STATISTICS_LANGUAGE_INPUTS, STATISTICS_LANGUAGE_INPUTS_SPARQL, STATISTICS_MODES, SUMMARY_MODES, \
51+
SPARQL_EXPLAIN_MODES, OPENCYPHER_EXPLAIN_MODES, OPENCYPHER_PLAN_CACHE_MODES, OPENCYPHER_DEFAULT_TIMEOUT, \
52+
OPENCYPHER_STATUS_STATE_MODES, normalize_service_name
5253
from graph_notebook.network import SPARQLNetwork
5354
from graph_notebook.network.gremlin.GremlinNetwork import parse_pattern_list_str, GremlinNetwork
5455
from graph_notebook.visualization.rows_and_columns import sparql_get_rows_and_columns, opencypher_get_rows_and_columns
@@ -255,22 +256,31 @@ def get_load_ids(neptune_client):
255256
return ids, res
256257

257258

258-
def process_statistics_400(is_summary: bool, response):
259+
def process_statistics_400(response, is_summary: bool = False, is_analytics: bool = False):
259260
bad_request_res = json.loads(response.text)
260261
res_code = bad_request_res['code']
261262
if res_code == 'StatisticsNotAvailableException':
262-
print("No statistics found. Please ensure that auto-generation of DFE statistics is enabled by running "
263-
"'%statistics' and checking if 'autoCompute' if set to True. Alternately, you can manually "
264-
"trigger statistics generation by running: '%statistics --mode refresh'.")
263+
print("No statistics found. ", end="")
264+
if not is_analytics:
265+
print("Please ensure that auto-generation of DFE statistics is enabled by running '%statistics' and "
266+
"checking if 'autoCompute' if set to True. Alternately, you can manually trigger statistics "
267+
"generation by running: '%statistics --mode refresh'.")
268+
return
265269
elif res_code == "BadRequestException":
266-
print("Unable to query the statistics endpoint. Please check that your Neptune instance is of size r5.large or "
267-
"greater in order to have DFE statistics enabled.")
268-
if is_summary and "Statistics is disabled" not in bad_request_res["detailedMessage"]:
269-
print("\nPlease also note that the Graph Summary API is only available in Neptune engine version 1.2.1.0 "
270-
"and later.")
271-
else:
272-
print("Query encountered 400 error, please see below.")
270+
if is_analytics:
271+
if bad_request_res["message"] == 'Bad route: /summary':
272+
logger.debug("Encountered bad route exception for Analytics, retrying with legacy statistics endpoint.")
273+
return 1
274+
else:
275+
print("Unable to query the statistics endpoint. Please check that your Neptune instance is of size "
276+
"r5.large or greater in order to have DFE statistics enabled.")
277+
if is_summary and "Statistics is disabled" not in bad_request_res["detailedMessage"]:
278+
print("\nPlease also note that the Graph Summary API is only available in Neptune engine version "
279+
"1.2.1.0 and later.")
280+
return
281+
print("Query encountered 400 error, please see below.")
273282
print(f"\nFull response: {bad_request_res}")
283+
return
274284

275285

276286
def mcl_to_bytes(mcl):
@@ -445,6 +455,7 @@ def stream_viewer(self,line):
445455
@line_magic
446456
@needs_local_scope
447457
@display_exceptions
458+
@neptune_db_only
448459
def statistics(self, line, local_ns: dict = None):
449460
parser = argparse.ArgumentParser()
450461
parser.add_argument('language', nargs='?', type=str.lower, default="propertygraph",
@@ -476,9 +487,9 @@ def statistics(self, line, local_ns: dict = None):
476487
statistics_res = self.client.statistics(args.language, args.summary, mode)
477488
if statistics_res.status_code == 400:
478489
if args.summary:
479-
process_statistics_400(True, statistics_res)
490+
process_statistics_400(statistics_res)
480491
else:
481-
process_statistics_400(False, statistics_res)
492+
process_statistics_400(statistics_res)
482493
return
483494
statistics_res.raise_for_status()
484495
statistics_res_json = statistics_res.json()
@@ -508,10 +519,21 @@ def summary(self, line, local_ns: dict = None):
508519
else:
509520
mode = "basic"
510521

511-
summary_res = self.client.statistics(args.language, True, mode)
522+
language_ep = args.language
523+
if self.client.is_analytics_domain():
524+
is_analytics = True
525+
if language_ep in STATISTICS_LANGUAGE_INPUTS_SPARQL:
526+
print("SPARQL is not supported for Neptune Analytics, defaulting to PropertyGraph.")
527+
language_ep = 'propertygraph'
528+
else:
529+
is_analytics = False
530+
summary_res = self.client.statistics(language_ep, True, mode, is_analytics)
512531
if summary_res.status_code == 400:
513-
process_statistics_400(True, summary_res)
514-
return
532+
retry_legacy = process_statistics_400(summary_res, is_summary=True, is_analytics=is_analytics)
533+
if retry_legacy == 1:
534+
summary_res = self.client.statistics(language_ep, True, mode, False)
535+
else:
536+
return
515537
summary_res.raise_for_status()
516538
summary_res_json = summary_res.json()
517539
if not args.silent:
@@ -530,6 +552,16 @@ def graph_notebook_host(self, line):
530552
self._generate_client_from_config(self.graph_notebook_config)
531553
print(f'set host to {self.graph_notebook_config.host}')
532554

555+
@line_magic
556+
def graph_notebook_service(self, line):
557+
if line == '':
558+
print(f'current service name: {self.graph_notebook_config.neptune_service}')
559+
return
560+
561+
self.graph_notebook_config.neptune_service = normalize_service_name(line)
562+
self._generate_client_from_config(self.graph_notebook_config)
563+
print(f'set service name to {self.graph_notebook_config.neptune_service}')
564+
533565
@magic_variables
534566
@cell_magic
535567
@needs_local_scope
@@ -1177,6 +1209,7 @@ def opencypher_status(self, line='', local_ns: dict = None):
11771209
@line_magic
11781210
@needs_local_scope
11791211
@display_exceptions
1212+
@neptune_db_only
11801213
def status(self, line='', local_ns: dict = None):
11811214
logger.info(f'calling for status on endpoint {self.graph_notebook_config.host}')
11821215
parser = argparse.ArgumentParser()
@@ -1547,6 +1580,7 @@ def load(self, line='', local_ns: dict = None):
15471580
value=str(args.concurrency),
15481581
placeholder=1,
15491582
min=1,
1583+
max=2**16,
15501584
disabled=False,
15511585
layout=widgets.Layout(display=concurrency_hbox_visibility,
15521586
width=widget_width)
@@ -1556,6 +1590,7 @@ def load(self, line='', local_ns: dict = None):
15561590
value=args.periodic_commit,
15571591
placeholder=0,
15581592
min=0,
1593+
max=1000000,
15591594
disabled=False,
15601595
layout=widgets.Layout(display=periodic_commit_hbox_visibility,
15611596
width=widget_width)
@@ -1770,13 +1805,12 @@ def on_button_clicked(b):
17701805
source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
17711806
source_format_hbox.children += (source_format_validation_label,)
17721807

1773-
if not arn.value.startswith('arn:aws') and source.value.startswith(
1774-
"s3://"): # only do this validation if we are using an s3 bucket.
1775-
validated = False
1776-
arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
1777-
arn_hbox.children += (arn_validation_label,)
1778-
17791808
if load_type == 'bulk':
1809+
if not arn.value.startswith('arn:aws') and source.value.startswith(
1810+
"s3://"): # only do this validation if we are using an s3 bucket.
1811+
validated = False
1812+
arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
1813+
arn_hbox.children += (arn_validation_label,)
17801814
dependencies_list = list(filter(None, dependencies.value.split('\n')))
17811815
if not len(dependencies_list) < 64:
17821816
validated = False
@@ -3105,9 +3139,15 @@ def handle_opencypher_status(self, line, local_ns):
31053139
parser.add_argument('-c', '--cancelQuery', action='store_true', default=False,
31063140
help='Tells the status command to cancel a query. This parameter does not take a value.')
31073141
parser.add_argument('-w', '--includeWaiting', action='store_true', default=False,
3108-
help='When set to true and other parameters are not present, causes status information '
3109-
'for waiting queries to be returned as well as for running queries. '
3110-
'This parameter does not take a value.')
3142+
help='Neptune DB only. When set to true and other parameters are not present, causes '
3143+
'status information for waiting queries to be returned as well as for running '
3144+
'queries. This parameter does not take a value.')
3145+
parser.add_argument('--state', type=str.upper, default='ALL',
3146+
help=f'Neptune Analytics only. Specifies what subset of query states to retrieve the '
3147+
f'status of. Default is ALL. Accepted values: ${OPENCYPHER_STATUS_STATE_MODES}')
3148+
parser.add_argument('-m', '--maxResults', type=int, default=200,
3149+
help=f'Neptune Analytics only. Sets an upper limit on the set of returned queries whose '
3150+
f'status matches --state. Default is 200.')
31113151
parser.add_argument('-s', '--silent-cancel', action='store_true', default=False,
31123152
help='If silent_cancel=true then the running query is cancelled and the HTTP response '
31133153
'code is 200. If silent_cancel is not present or silent_cancel=false, '
@@ -3116,21 +3156,50 @@ def handle_opencypher_status(self, line, local_ns):
31163156
parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
31173157
args = parser.parse_args(line.split())
31183158

3159+
using_analytics = self.client.is_analytics_domain()
31193160
if not args.cancelQuery:
3120-
if args.includeWaiting and not args.queryId:
3121-
res = self.client.opencypher_status(include_waiting=args.includeWaiting)
3161+
query_id = ''
3162+
include_waiting = None
3163+
state = ''
3164+
max_results = None
3165+
if args.includeWaiting and not args.queryId and not self.client.is_analytics_domain():
3166+
include_waiting = args.includeWaiting
3167+
elif args.state and not args.queryId and self.client.is_analytics_domain():
3168+
state = args.state
3169+
max_results = args.maxResults
31223170
else:
3123-
res = self.client.opencypher_status(query_id=args.queryId)
3171+
query_id = args.queryId
3172+
res = self.client.opencypher_status(query_id=query_id,
3173+
include_waiting=include_waiting,
3174+
state=state,
3175+
max_results=max_results,
3176+
use_analytics_endpoint=using_analytics)
3177+
if using_analytics and res.status_code == 400 and 'Bad route: /queries' in res.json()["message"]:
3178+
res = self.client.opencypher_status(query_id=query_id,
3179+
include_waiting=include_waiting,
3180+
state=state,
3181+
max_results=max_results,
3182+
use_analytics_endpoint=False)
31243183
res.raise_for_status()
31253184
else:
31263185
if args.queryId == '':
31273186
if not args.silent:
31283187
print(OPENCYPHER_CANCEL_HINT_MSG)
31293188
return
31303189
else:
3131-
res = self.client.opencypher_cancel(args.queryId, args.silent_cancel)
3190+
res = self.client.opencypher_cancel(args.queryId,
3191+
silent=args.silent_cancel,
3192+
use_analytics_endpoint=using_analytics)
3193+
if using_analytics and res.status_code == 400 and 'Bad route: /queries' in res.json()["message"]:
3194+
res = self.client.opencypher_cancel(args.queryId,
3195+
silent=args.silent_cancel,
3196+
use_analytics_endpoint=False)
31323197
res.raise_for_status()
3133-
js = res.json()
3134-
store_to_ns(args.store_to, js, local_ns)
3135-
if not args.silent:
3136-
print(json.dumps(js, indent=2))
3198+
if using_analytics and args.cancelQuery:
3199+
if not args.silent:
3200+
print(f'Submitted cancellation request for query ID: {args.queryId}')
3201+
else:
3202+
js = res.json()
3203+
store_to_ns(args.store_to, js, local_ns)
3204+
if not args.silent:
3205+
print(json.dumps(js, indent=2))

src/graph_notebook/neptune/client.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,15 @@
122122

123123
STATISTICS_MODES = ["", "status", "disableAutoCompute", "enableAutoCompute", "refresh", "delete"]
124124
SUMMARY_MODES = ["", "basic", "detailed"]
125-
STATISTICS_LANGUAGE_INPUTS = ["propertygraph", "pg", "gremlin", "oc", "opencypher", "sparql", "rdf"]
125+
STATISTICS_LANGUAGE_INPUTS_PG = ["propertygraph", "pg", "gremlin", "oc", "opencypher"]
126+
STATISTICS_LANGUAGE_INPUTS_SPARQL = ["sparql", "rdf"]
127+
STATISTICS_LANGUAGE_INPUTS = STATISTICS_LANGUAGE_INPUTS_PG + STATISTICS_LANGUAGE_INPUTS_SPARQL
126128

127129
SPARQL_EXPLAIN_MODES = ['dynamic', 'static', 'details']
128130
OPENCYPHER_EXPLAIN_MODES = ['dynamic', 'static', 'details']
129131
OPENCYPHER_PLAN_CACHE_MODES = ['auto', 'enabled', 'disabled']
130132
OPENCYPHER_DEFAULT_TIMEOUT = 120000
133+
OPENCYPHER_STATUS_STATE_MODES = ['ALL', 'RUNNING', 'WAITING', 'CANCELLING']
131134

132135

133136
def is_allowed_neptune_host(hostname: str, host_allowlist: list):
@@ -405,7 +408,7 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
405408
if plan_cache:
406409
data['planCache'] = plan_cache
407410
if query_timeout:
408-
headers['query_timeout_millis'] = str(query_timeout)
411+
data['queryTimeoutMilliseconds'] = str(query_timeout)
409412
else:
410413
url += 'db/neo4j/tx/commit'
411414
headers['content-type'] = 'application/json'
@@ -441,16 +444,20 @@ def opencyper_bolt(self, query: str, **kwargs):
441444
driver.close()
442445
return data
443446

444-
def opencypher_status(self, query_id: str = '', include_waiting: bool = False):
447+
def opencypher_status(self, query_id: str = '', include_waiting: bool = False, state: str = '',
448+
max_results: int = None, use_analytics_endpoint: bool = False):
449+
if use_analytics_endpoint:
450+
return self._analytics_query_status(query_id=query_id, state=state, max_results=max_results)
445451
kwargs = {}
446452
if include_waiting:
447453
kwargs['includeWaiting'] = True
448454
return self._query_status('openCypher', query_id=query_id, **kwargs)
449455

450-
def opencypher_cancel(self, query_id, silent: bool = False):
456+
def opencypher_cancel(self, query_id, silent: bool = False, use_analytics_endpoint: bool = False):
451457
if type(query_id) is not str or query_id == '':
452458
raise ValueError('query_id must be a non-empty string')
453-
459+
if use_analytics_endpoint:
460+
return self._analytics_query_status(query_id=query_id, cancel_query=True)
454461
return self._query_status('openCypher', query_id=query_id, cancelQuery=True, silent=silent)
455462

456463
def get_opencypher_driver(self):
@@ -808,7 +815,25 @@ def _query_status(self, language: str, *, query_id: str = '', **kwargs) -> reque
808815
res = self._http_session.send(req, verify=self.ssl_verify)
809816
return res
810817

811-
def statistics(self, language: str, summary: bool = False, mode: str = '') -> requests.Response:
818+
def _analytics_query_status(self, query_id: str = '', state: str = '', max_results: int = None,
819+
cancel_query: bool = False) -> requests.Response:
820+
url = f'{self._http_protocol}://{self.host}:{self.port}/queries'
821+
if query_id != '':
822+
url += f'/{query_id}'
823+
elif state != '':
824+
url += f'?state={state}&maxResults={max_results}'
825+
826+
method = 'DELETE' if cancel_query else 'GET'
827+
828+
headers = {
829+
'Content-Type': 'application/x-www-form-urlencoded'
830+
}
831+
req = self._prepare_request(method, url, headers=headers)
832+
res = self._http_session.send(req, verify=self.ssl_verify)
833+
return res
834+
835+
def statistics(self, language: str, summary: bool = False, mode: str = '',
836+
use_analytics_endpoint: bool = False) -> requests.Response:
812837
headers = {
813838
'Accept': 'application/json'
814839
}
@@ -817,11 +842,12 @@ def statistics(self, language: str, summary: bool = False, mode: str = '') -> re
817842
elif language == "sparql":
818843
language = "rdf"
819844

820-
url = f'{self._http_protocol}://{self.host}:{self.port}/{language}/statistics'
845+
base_url = f'{self._http_protocol}://{self.host}:{self.port}'
846+
url = base_url + f'/{language}/statistics'
821847
data = {'mode': mode}
822848

823849
if summary:
824-
summary_url = url + '/summary'
850+
summary_url = (base_url if use_analytics_endpoint else url) + '/summary'
825851
if mode:
826852
summary_mode_param = '?mode=' + mode
827853
summary_url += summary_mode_param

0 commit comments

Comments
 (0)