1919 ClientMiddleware ,
2020 ClientMiddlewareFactory ,
2121 FlightDescriptor ,
22+ FlightInternalError ,
2223 FlightMetadataReader ,
2324 FlightStreamWriter ,
2425 FlightTimedOutError ,
3435 wait_exponential ,
3536)
3637
38+ from graphdatascience .retry_utils .retry_config import RetryConfig
3739from graphdatascience .retry_utils .retry_utils import before_log
3840
3941from ..semantic_version .semantic_version import SemanticVersion
4042from ..version import __version__
4143from .arrow_endpoint_version import ArrowEndpointVersion
4244from .arrow_info import ArrowInfo
4345
44- _arrow_client_logger = logging .getLogger ("gds_arrow_client" )
45-
4646
4747class GdsArrowClient :
4848 @staticmethod
@@ -53,6 +53,7 @@ def create(
5353 disable_server_verification : bool = False ,
5454 tls_root_certs : Optional [bytes ] = None ,
5555 connection_string_override : Optional [str ] = None ,
56+ retry_config : Optional [RetryConfig ] = None ,
5657 ) -> GdsArrowClient :
5758 connection_string : str
5859 if connection_string_override is not None :
@@ -64,8 +65,20 @@ def create(
6465
6566 arrow_endpoint_version = ArrowEndpointVersion .from_arrow_info (arrow_info .versions )
6667
68+ if retry_config is None :
69+ retry_config = RetryConfig (
70+ retry = retry_any (
71+ retry_if_exception_type (FlightTimedOutError ),
72+ retry_if_exception_type (FlightUnavailableError ),
73+ retry_if_exception_type (FlightInternalError ),
74+ ),
75+ stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
76+ wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
77+ )
78+
6779 return GdsArrowClient (
6880 host ,
81+ retry_config ,
6982 int (port ),
7083 auth ,
7184 encrypted ,
@@ -77,6 +90,7 @@ def create(
7790 def __init__ (
7891 self ,
7992 host : str ,
93+ retry_config : RetryConfig ,
8094 port : int = 8491 ,
8195 auth : Optional [tuple [str , str ]] = None ,
8296 encrypted : bool = False ,
@@ -105,6 +119,8 @@ def __init__(
105119 The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1)
106120 user_agent: Optional[str]
107121 The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
122+ retry_config: RetryConfig
123+ The retry configuration to use for the Arrow requests sent by the client.
108124 """
109125 self ._arrow_endpoint_version = arrow_endpoint_version
110126 self ._host = host
@@ -114,6 +130,8 @@ def __init__(
114130 self ._disable_server_verification = disable_server_verification
115131 self ._tls_root_certs = tls_root_certs
116132 self ._user_agent = user_agent
133+ self ._retry_config = retry_config
134+ self ._logger = logging .getLogger ("gds_arrow_client" )
117135
118136 if auth :
119137 self ._auth_middleware = AuthMiddleware (auth )
@@ -151,13 +169,6 @@ def connection_info(self) -> tuple[str, int]:
151169 """
152170 return self ._host , self ._port
153171
154- @retry (
155- reraise = True ,
156- before = before_log ("Request token" , _arrow_client_logger , logging .DEBUG ),
157- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
158- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
159- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
160- )
161172 def request_token (self ) -> Optional [str ]:
162173 """
163174 Requests a token from the server and returns it.
@@ -167,9 +178,21 @@ def request_token(self) -> Optional[str]:
167178 Optional[str]
168179 a token from the server and returns it.
169180 """
170- if self ._auth :
181+
182+ @retry (
183+ reraise = True ,
184+ before = before_log ("Request token" , self ._logger , logging .DEBUG ),
185+ retry = self ._retry_config .retry ,
186+ stop = self ._retry_config .stop ,
187+ wait = self ._retry_config .wait ,
188+ )
189+ def auth_with_retry () -> None :
171190 client = self ._client ()
172- client .authenticate_basic_token (self ._auth [0 ], self ._auth [1 ])
191+ if self ._auth :
192+ client .authenticate_basic_token (self ._auth [0 ], self ._auth [1 ])
193+
194+ if self ._auth :
195+ auth_with_retry ()
173196 return self ._auth_middleware .token ()
174197 else :
175198 return "IGNORED"
@@ -220,7 +243,7 @@ def get_node_properties(
220243 if node_labels :
221244 config ["node_labels" ] = node_labels
222245
223- return self ._do_get (database , graph_name , proc , concurrency , config )
246+ return self ._do_get_with_retry (database , graph_name , proc , concurrency , config )
224247
225248 def get_node_labels (self , graph_name : str , database : str , concurrency : Optional [int ] = None ) -> pandas .DataFrame :
226249 """
@@ -240,7 +263,7 @@ def get_node_labels(self, graph_name: str, database: str, concurrency: Optional[
240263 DataFrame
241264 The requested nodes as a DataFrame
242265 """
243- return self ._do_get (database , graph_name , "gds.graph.nodeLabels.stream" , concurrency , {})
266+ return self ._do_get_with_retry (database , graph_name , "gds.graph.nodeLabels.stream" , concurrency , {})
244267
245268 def get_relationships (
246269 self , graph_name : str , database : str , relationship_types : list [str ], concurrency : Optional [int ] = None
@@ -264,7 +287,7 @@ def get_relationships(
264287 DataFrame
265288 The requested relationships as a DataFrame
266289 """
267- return self ._do_get (
290+ return self ._do_get_with_retry (
268291 database ,
269292 graph_name ,
270293 "gds.graph.relationships.stream" ,
@@ -312,7 +335,7 @@ def get_relationship_properties(
312335 if relationship_types :
313336 config ["relationship_types" ] = relationship_types
314337
315- return self ._do_get (database , graph_name , proc , concurrency , config )
338+ return self ._do_get_with_retry (database , graph_name , proc , concurrency , config )
316339
317340 def create_graph (
318341 self ,
@@ -598,40 +621,31 @@ def _client(self) -> flight.FlightClient:
598621 self ._flight_client = self ._instantiate_flight_client ()
599622 return self ._flight_client
600623
601- @retry (
602- reraise = True ,
603- before = before_log ("Send action" , _arrow_client_logger , logging .DEBUG ),
604- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
605- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
606- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
607- )
608624 def _send_action (self , action_type : str , meta_data : dict [str , Any ]) -> dict [str , Any ]:
609625 action_type = self ._versioned_action_type (action_type )
626+ client = self ._client ()
610627
611- try :
612- client = self ._client ()
613- result = client .do_action (flight .Action (action_type , json .dumps (meta_data ).encode ("utf-8" )))
628+ @retry (
629+ reraise = True ,
630+ before = before_log ("Send action" , self ._logger , logging .DEBUG ),
631+ retry = self ._retry_config .retry ,
632+ stop = self ._retry_config .stop ,
633+ wait = self ._retry_config .wait ,
634+ )
635+ def send_with_retry () -> dict [str , Any ]:
636+ try :
637+ result = client .do_action (flight .Action (action_type , json .dumps (meta_data ).encode ("utf-8" )))
614638
615- # Consume result fully to sanity check and avoid cancelled streams
616- collected_result = list (result )
617- assert len (collected_result ) == 1
639+ # Consume result fully to sanity check and avoid cancelled streams
640+ collected_result = list (result )
641+ assert len (collected_result ) == 1
618642
619- return json .loads (collected_result [0 ].body .to_pybytes ().decode ()) # type: ignore
620- except Exception as e :
621- self .handle_flight_error (e )
622- raise e # unreachable
623-
624- @retry (
625- reraise = True ,
626- before = before_log ("Do put" , _arrow_client_logger , logging .DEBUG ),
627- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
628- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
629- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
630- )
631- def _safe_do_put (
632- self , upload_descriptor : FlightDescriptor , schema : Schema
633- ) -> tuple [FlightStreamWriter , FlightMetadataReader ]:
634- return self ._client ().do_put (upload_descriptor , schema ) # type: ignore
643+ return json .loads (collected_result [0 ].body .to_pybytes ().decode ()) # type: ignore
644+ except Exception as e :
645+ self .handle_flight_error (e )
646+ raise e # unreachable
647+
648+ return send_with_retry ()
635649
636650 def _upload_data (
637651 self ,
@@ -651,18 +665,26 @@ def _upload_data(
651665 flight_descriptor = self ._versioned_flight_descriptor ({"name" : graph_name , "entity_type" : entity_type })
652666 upload_descriptor = flight .FlightDescriptor .for_command (json .dumps (flight_descriptor ).encode ("utf-8" ))
653667
654- put_stream , ack_stream = self ._safe_do_put (upload_descriptor , batches [0 ].schema )
668+ @retry (
669+ reraise = True ,
670+ before = before_log ("Do put" , self ._logger , logging .DEBUG ),
671+ retry = self ._retry_config .retry ,
672+ stop = self ._retry_config .stop ,
673+ wait = self ._retry_config .wait ,
674+ )
675+ def safe_do_put (
676+ upload_descriptor : FlightDescriptor , schema : Schema
677+ ) -> tuple [FlightStreamWriter , FlightMetadataReader ]:
678+ return self ._client ().do_put (upload_descriptor , schema ) # type: ignore
679+
680+ put_stream , ack_stream = safe_do_put (upload_descriptor , batches [0 ].schema )
655681
656682 @retry (
657683 reraise = True ,
658- before = before_log ("Upload batch" , _arrow_client_logger , logging .DEBUG ),
659- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
660- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
661- retry = (
662- retry_if_exception_type (flight .FlightUnavailableError )
663- | retry_if_exception_type (flight .FlightTimedOutError )
664- | retry_if_exception_type (flight .FlightInternalError )
665- ),
684+ before = before_log ("Upload batch" , self ._logger , logging .DEBUG ),
685+ retry = self ._retry_config .retry ,
686+ stop = self ._retry_config .stop ,
687+ wait = self ._retry_config .wait ,
666688 )
667689 def upload_batch (p : RecordBatch ) -> None :
668690 put_stream .write_batch (p )
@@ -676,13 +698,26 @@ def upload_batch(p: RecordBatch) -> None:
676698 except Exception as e :
677699 GdsArrowClient .handle_flight_error (e )
678700
679- @retry (
680- reraise = True ,
681- before = before_log ("Do get" , _arrow_client_logger , logging .DEBUG ),
682- retry = retry_any (retry_if_exception_type (FlightTimedOutError ), retry_if_exception_type (FlightUnavailableError )),
683- stop = (stop_after_delay (10 ) | stop_after_attempt (5 )),
684- wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
685- )
701+ def _do_get_with_retry (
702+ self ,
703+ database : str ,
704+ graph_name : str ,
705+ procedure_name : str ,
706+ concurrency : Optional [int ],
707+ configuration : dict [str , Any ],
708+ ) -> pandas .DataFrame :
709+ @retry (
710+ reraise = True ,
711+ before = before_log ("Do get" , self ._logger , logging .DEBUG ),
712+ retry = self ._retry_config .retry ,
713+ stop = self ._retry_config .stop ,
714+ wait = self ._retry_config .wait ,
715+ )
716+ def safe_do_get () -> pandas .DataFrame :
717+ return self ._do_get (database , graph_name , procedure_name , concurrency , configuration )
718+
719+ return safe_do_get ()
720+
686721 def _do_get (
687722 self ,
688723 database : str ,
0 commit comments