diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 48b11d9342..2f7cc06239 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -65,6 +65,7 @@ from .types.type import TypeCode from .data_types import JsonObject, Interval from .transaction import BatchTransactionId, DefaultTransactionOptions +from .exceptions import SpannerError, wrap_with_request_id from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1.client import Client @@ -88,6 +89,9 @@ # google.cloud.spanner_v1 "__version__", "param_types", + # google.cloud.spanner_v1.exceptions + "SpannerError", + "wrap_with_request_id", # google.cloud.spanner_v1.client "Client", # google.cloud.spanner_v1.keyset diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 8a200fe812..17a47c4648 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -22,6 +22,7 @@ import threading import logging import uuid +from contextlib import contextmanager from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value @@ -34,8 +35,12 @@ from google.cloud.spanner_v1.types import ExecuteSqlRequest from google.cloud.spanner_v1.types import TransactionOptions from google.cloud.spanner_v1.data_types import JsonObject, Interval -from google.cloud.spanner_v1.request_id_header import with_request_id +from google.cloud.spanner_v1.request_id_header import ( + with_request_id, + with_request_id_metadata_only, +) from google.cloud.spanner_v1.types import TypeCode +from google.cloud.spanner_v1.exceptions import wrap_with_request_id from google.rpc.error_details_pb2 import RetryInfo @@ -568,7 +573,10 @@ def _retry_on_aborted_exception( ): """ Handles retry logic for Aborted exceptions, considering the deadline. + Also handles SpannerError that wraps Aborted exceptions. """ + from google.cloud.spanner_v1.exceptions import SpannerError + attempts = 0 while True: try: @@ -582,6 +590,17 @@ def _retry_on_aborted_exception( default_retry_delay=default_retry_delay, ) continue + except SpannerError as exc: + # Check if the wrapped error is Aborted + if isinstance(exc._error, Aborted): + _delay_until_retry( + exc._error, + deadline=deadline, + attempts=attempts, + default_retry_delay=default_retry_delay, + ) + continue + raise def _retry( @@ -600,10 +619,13 @@ def _retry( delay: The delay in seconds between retries. allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry. Passing allowed_exceptions as None will lead to retrying for all exceptions. + Also handles SpannerError wrapping allowed exceptions. Returns: The result of the function if it is successful, or raises the last exception if all retries fail. """ + from google.cloud.spanner_v1.exceptions import SpannerError + retries = 0 while retries <= retry_count: if retries > 0 and before_next_retry: @@ -612,14 +634,23 @@ def _retry( try: return func() except Exception as exc: - if ( - allowed_exceptions is None or exc.__class__ in allowed_exceptions - ) and retries < retry_count: + # Check if exception is allowed directly or wrapped in SpannerError + exc_to_check = exc + if isinstance(exc, SpannerError): + exc_to_check = exc._error + + is_allowed = ( + allowed_exceptions is None + or exc_to_check.__class__ in allowed_exceptions + ) + + if is_allowed and retries < retry_count: if ( allowed_exceptions is not None - and allowed_exceptions[exc.__class__] is not None + and exc_to_check.__class__ in allowed_exceptions + and allowed_exceptions[exc_to_check.__class__] is not None ): - allowed_exceptions[exc.__class__](exc) + allowed_exceptions[exc_to_check.__class__](exc_to_check) time.sleep(delay) delay = delay * 2 retries = retries + 1 @@ -767,9 +798,67 @@ def reset(self): def _metadata_with_request_id(*args, **kwargs): + """Return metadata with request ID header. + + This function returns only the metadata list (not a tuple), + maintaining backward compatibility with existing code. + + Args: + *args: Arguments to pass to with_request_id + **kwargs: Keyword arguments to pass to with_request_id + + Returns: + list: gRPC metadata with request ID header + """ + return with_request_id_metadata_only(*args, **kwargs) + + +def _metadata_with_request_id_and_req_id(*args, **kwargs): + """Return both metadata and request ID string. + + This is used when we need to augment errors with the request ID. + + Args: + *args: Arguments to pass to with_request_id + **kwargs: Keyword arguments to pass to with_request_id + + Returns: + tuple: (metadata, request_id) + """ return with_request_id(*args, **kwargs) +def _augment_error_with_request_id(error, request_id=None): + """Augment an error with request ID information. + + Args: + error: The error to augment (typically GoogleAPICallError) + request_id (str): The request ID to include + + Returns: + The augmented error with request ID information + """ + return wrap_with_request_id(error, request_id) + + +@contextmanager +def _augment_errors_with_request_id(request_id): + """Context manager to augment exceptions with request ID. + + Args: + request_id (str): The request ID to include in exceptions + + Yields: + None + """ + try: + yield + except Exception as exc: + augmented = _augment_error_with_request_id(exc, request_id) + # Use exception chaining to preserve the original exception + raise augmented from exc + + def _merge_Transaction_Options( defaultTransactionOptions: TransactionOptions, mergeTransactionOptions: TransactionOptions, diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 0792e600dc..e70d214783 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -252,20 +252,22 @@ def wrapped_method(): max_commit_delay=max_commit_delay, request_options=request_options, ) + # This code is retried due to ABORTED, hence nth_request + # should be increased. attempt can only be increased if + # we encounter UNAVAILABLE or INTERNAL. + call_metadata, error_augmenter = database.with_error_augmentation( + getattr(database, "_next_nth_request", 0), + 1, + metadata, + span, + ) commit_method = functools.partial( api.commit, request=commit_request, - metadata=database.metadata_with_request_id( - # This code is retried due to ABORTED, hence nth_request - # should be increased. attempt can only be increased if - # we encounter UNAVAILABLE or INTERNAL. - getattr(database, "_next_nth_request", 0), - 1, - metadata, - span, - ), + metadata=call_metadata, ) - return commit_method() + with error_augmenter: + return commit_method() response = _retry_on_aborted_exception( wrapped_method, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 33c442602c..c870b6f8ea 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -25,7 +25,6 @@ import google.auth.credentials from google.api_core.retry import Retry -from google.api_core.retry import if_exception_type from google.cloud.exceptions import NotFound from google.api_core.exceptions import Aborted from google.api_core import gapic_v1 @@ -55,6 +54,8 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, _metadata_with_request_id, + _augment_errors_with_request_id, + _metadata_with_request_id_and_req_id, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -496,6 +497,66 @@ def metadata_with_request_id( span, ) + def metadata_and_request_id( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + """Return metadata and request ID string. + + This method returns both the gRPC metadata with request ID header + and the request ID string itself, which can be used to augment errors. + + Args: + nth_request: The request sequence number + nth_attempt: The attempt number (for retries) + prior_metadata: Prior metadata to include + span: Optional span for tracing + + Returns: + tuple: (metadata_list, request_id_string) + """ + if span is None: + span = get_current_span() + + return _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + """Context manager for gRPC calls with error augmentation. + + This context manager provides both metadata with request ID and + automatically augments any exceptions with the request ID. + + Args: + nth_request: The request sequence number + nth_attempt: The attempt number (for retries) + prior_metadata: Prior metadata to include + span: Optional span for tracing + + Yields: + tuple: (metadata_list, context_manager) + """ + if span is None: + span = get_current_span() + + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + + return metadata, _augment_errors_with_request_id(request_id) + def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented @@ -783,16 +844,18 @@ def execute_pdml(): try: add_span_event(span, "Starting BeginTransaction") - txn = api.begin_transaction( - session=session.name, - options=txn_options, - metadata=self.metadata_with_request_id( - self._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = self.with_error_augmentation( + self._next_nth_request, + 1, + metadata, + span, ) + with error_augmenter: + txn = api.begin_transaction( + session=session.name, + options=txn_options, + metadata=call_metadata, + ) txn_selector = TransactionSelector(id=txn.id) @@ -2052,7 +2115,8 @@ def _retry_on_aborted(func, retry_config): """Helper for :meth:`Database.execute_partitioned_dml`. Wrap function in a Retry that will retry on Aborted exceptions - with the retry config specified. + with the retry config specified. Also handles SpannerError that + wraps Aborted exceptions. :type func: callable :param func: the function to be retried on Aborted exceptions @@ -2060,5 +2124,15 @@ def _retry_on_aborted(func, retry_config): :type retry_config: Retry :param retry_config: retry object with the settings to be used """ - retry = retry_config.with_predicate(if_exception_type(Aborted)) + from google.cloud.spanner_v1.exceptions import SpannerError + + def _is_aborted_or_wrapped_aborted(exc): + """Check if exception is Aborted or SpannerError wrapping Aborted.""" + if isinstance(exc, Aborted): + return True + if isinstance(exc, SpannerError) and isinstance(exc._error, Aborted): + return True + return False + + retry = retry_config.with_predicate(_is_aborted_or_wrapped_aborted) return retry(func) diff --git a/google/cloud/spanner_v1/exceptions.py b/google/cloud/spanner_v1/exceptions.py new file mode 100644 index 0000000000..5abdc9416c --- /dev/null +++ b/google/cloud/spanner_v1/exceptions.py @@ -0,0 +1,91 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Spanner exception classes with request ID support.""" + +from google.api_core.exceptions import GoogleAPICallError + + +class SpannerError(Exception): + """Base Spanner exception that includes request ID. + + This class wraps an error (typically GoogleAPICallError) to add request ID information + for better debugging and error correlation. + + Args: + error: The original error to wrap (typically GoogleAPICallError) + request_id (str): The request ID associated with this error + """ + + def __init__(self, error, request_id=None): + self._error = error + self._request_id = request_id or "" + # Initialize the base Exception with the message + super().__init__(str(self)) + + @property + def request_id(self): + return self._request_id + + @property + def wrapped_error(self): + """Returns the wrapped error.""" + return self._error + + def is_error_type(self, error_type): + """Check if the wrapped error is of a specific type. + + Args: + error_type: The exception class to check against. + + Returns: + bool: True if the wrapped error is an instance of error_type. + """ + return isinstance(self._error, error_type) + + def __str__(self): + s = str(self._error) + if self._request_id: + s = f"{s}, request_id = {self._request_id!r}" + return s + + def __repr__(self): + return f"SpannerError(error={self._error!r}, request_id={self._request_id!r})" + + def __getattr__(self, name): + if name == "request_id": + return self._request_id + return getattr(self._error, name) + + +def wrap_with_request_id(error, request_id=None): + """Add request ID information to a GoogleAPICallError. + + This function adds request_id as an attribute to the exception rather than + wrapping it in a new type, preserving the original exception type for + exception handling compatibility. + + Args: + error: The error to augment. If not a GoogleAPICallError, returns as-is + request_id (str): The request ID to include + + Returns: + The original error with request_id attribute added (if GoogleAPICallError + and request_id is provided), otherwise returns the original error unchanged. + For GoogleAPICallError, the error is wrapped in SpannerError to include + request_id in the string representation. + """ + if isinstance(error, GoogleAPICallError) and request_id: + return SpannerError(error, request_id) + return error diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index a75c13cb7a..348a01e940 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -259,15 +259,17 @@ def bind(self, database): f"Creating {request.session_count} sessions", span_event_attributes, ) - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + database._next_nth_request, + 1, + metadata, + span, ) + with error_augmenter: + resp = api.batch_create_sessions( + request=request, + metadata=call_metadata, + ) add_span_event( span, @@ -570,15 +572,17 @@ def bind(self, database): ) as span, MetricsCapture(): returned_session_count = 0 while returned_session_count < self.size: - resp = api.batch_create_sessions( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + database._next_nth_request, + 1, + metadata, + span, ) + with error_augmenter: + resp = api.batch_create_sessions( + request=request, + metadata=call_metadata, + ) add_span_event( span, diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py index 95c25b94f7..fb84e56100 100644 --- a/google/cloud/spanner_v1/request_id_header.py +++ b/google/cloud/spanner_v1/request_id_header.py @@ -43,6 +43,19 @@ def with_request_id( all_metadata = (other_metadata or []).copy() all_metadata.append((REQ_ID_HEADER_KEY, req_id)) + if span: + span.set_attribute(X_GOOG_SPANNER_REQUEST_ID_SPAN_ATTR, req_id) + + return all_metadata, req_id + + +def with_request_id_metadata_only( + client_id, channel_id, nth_request, attempt, other_metadata=[], span=None +): + req_id = build_request_id(client_id, channel_id, nth_request, attempt) + all_metadata = (other_metadata or []).copy() + all_metadata.append((REQ_ID_HEADER_KEY, req_id)) + if span: span.set_attribute(X_GOOG_SPANNER_REQUEST_ID_SPAN_ATTR, req_id) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 4c29014e15..c09a68023f 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -25,13 +25,14 @@ from google.api_core.gapic_v1 import method from google.cloud.spanner_v1._helpers import _delay_until_retry from google.cloud.spanner_v1._helpers import _get_retry_delay - -from google.cloud.spanner_v1 import ExecuteSqlRequest -from google.cloud.spanner_v1 import CreateSessionRequest from google.cloud.spanner_v1._helpers import ( _metadata_with_prefix, _metadata_with_leader_aware_routing, ) +from google.cloud.spanner_v1.exceptions import SpannerError + +from google.cloud.spanner_v1 import ExecuteSqlRequest +from google.cloud.spanner_v1 import CreateSessionRequest from google.cloud.spanner_v1._opentelemetry_tracing import ( add_span_event, get_current_span, @@ -185,6 +186,7 @@ def create(self): if self._is_multiplexed else "CloudSpanner.CreateSession" ) + nth_request = database._next_nth_request with trace_call( span_name, self, @@ -192,15 +194,14 @@ def create(self): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): - session_pb = api.create_session( - request=create_session_request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span ) + with error_augmenter: + session_pb = api.create_session( + request=create_session_request, + metadata=call_metadata, + ) self._session_id = session_pb.name.split("/")[-1] def exists(self): @@ -235,26 +236,26 @@ def exists(self): ) observability_options = getattr(self._database, "observability_options", None) + nth_request = database._next_nth_request with trace_call( "CloudSpanner.GetSession", self, observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): - try: - api.get_session( - name=self.name, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), - ) - span.set_attribute("session_found", True) - except NotFound: - span.set_attribute("session_found", False) - return False + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span + ) + with error_augmenter: + try: + api.get_session( + name=self.name, + metadata=call_metadata, + ) + span.set_attribute("session_found", True) + except NotFound: + span.set_attribute("session_found", False) + return False return True @@ -288,6 +289,7 @@ def delete(self): api = database.spanner_api metadata = _metadata_with_prefix(database.name) observability_options = getattr(self._database, "observability_options", None) + nth_request = database._next_nth_request with trace_call( "CloudSpanner.DeleteSession", self, @@ -298,15 +300,14 @@ def delete(self): observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): - api.delete_session( - name=self.name, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span ) + with error_augmenter: + api.delete_session( + name=self.name, + metadata=call_metadata, + ) def ping(self): """Ping the session to keep it alive by executing "SELECT 1". @@ -318,18 +319,19 @@ def ping(self): database = self._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) + nth_request = database._next_nth_request with trace_call("CloudSpanner.Session.ping", self) as span: - request = ExecuteSqlRequest(session=self.name, sql="SELECT 1") - api.execute_sql( - request=request, - metadata=database.metadata_with_request_id( - database._next_nth_request, - 1, - _metadata_with_prefix(database.name), - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, 1, metadata, span ) + with error_augmenter: + request = ExecuteSqlRequest(session=self.name, sql="SELECT 1") + api.execute_sql( + request=request, + metadata=call_metadata, + ) def snapshot(self, **kw): """Create a snapshot to perform a set of reads with shared staleness. @@ -570,10 +572,20 @@ def run_in_transaction(self, func, *args, **kw): try: return_value = func(txn, *args, **kw) - except Aborted as exc: + except (Aborted, SpannerError) as exc: + # Handle both raw Aborted and SpannerError wrapping Aborted + aborted_exc = exc + if isinstance(exc, SpannerError) and isinstance( + exc._error, Aborted + ): + aborted_exc = exc._error + elif not isinstance(exc, Aborted): + # SpannerError wrapping non-Aborted, re-raise + raise + previous_transaction_id = txn._transaction_id delay_seconds = _get_retry_delay( - exc.errors[0], + aborted_exc.errors[0], attempts, default_retry_delay=default_retry_delay, ) @@ -585,7 +597,10 @@ def run_in_transaction(self, func, *args, **kw): attributes, ) _delay_until_retry( - exc, deadline, attempts, default_retry_delay=default_retry_delay + aborted_exc, + deadline, + attempts, + default_retry_delay=default_retry_delay, ) continue @@ -613,10 +628,20 @@ def run_in_transaction(self, func, *args, **kw): max_commit_delay=max_commit_delay, ) - except Aborted as exc: + except (Aborted, SpannerError) as exc: + # Handle both raw Aborted and SpannerError wrapping Aborted + aborted_exc = exc + if isinstance(exc, SpannerError) and isinstance( + exc._error, Aborted + ): + aborted_exc = exc._error + elif not isinstance(exc, Aborted): + # SpannerError wrapping non-Aborted, re-raise + raise + previous_transaction_id = txn._transaction_id delay_seconds = _get_retry_delay( - exc.errors[0], + aborted_exc.errors[0], attempts, default_retry_delay=default_retry_delay, ) @@ -628,7 +653,10 @@ def run_in_transaction(self, func, *args, **kw): attributes, ) _delay_until_retry( - exc, deadline, attempts, default_retry_delay=default_retry_delay + aborted_exc, + deadline, + attempts, + default_retry_delay=default_retry_delay, ) except GoogleAPICallError: diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 89cbc9fe88..af34f3ac2f 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -47,6 +47,7 @@ _check_rst_stream_error, _SessionWrapper, AtomicCounter, + _augment_error_with_request_id, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call, add_span_event from google.cloud.spanner_v1.streamed import StreamedResultSet @@ -103,6 +104,7 @@ def _restart_on_unavailable( iterator = None attempt = 1 nth_request = getattr(request_id_manager, "_next_nth_request", 0) + current_request_id = None while True: try: @@ -115,14 +117,18 @@ def _restart_on_unavailable( observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): + ( + call_metadata, + current_request_id, + ) = request_id_manager.metadata_and_request_id( + nth_request, + attempt, + metadata, + span, + ) iterator = method( request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), + metadata=call_metadata, ) # Add items from iterator to buffer. @@ -158,14 +164,18 @@ def _restart_on_unavailable( transaction_selector = transaction._build_transaction_selector_pb() request.transaction = transaction_selector attempt += 1 + ( + call_metadata, + current_request_id, + ) = request_id_manager.metadata_and_request_id( + nth_request, + attempt, + metadata, + span, + ) iterator = method( request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), + metadata=call_metadata, ) continue @@ -175,7 +185,7 @@ def _restart_on_unavailable( for resumable_message in _STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES ) if not resumable_error: - raise + raise _augment_error_with_request_id(exc, current_request_id) del item_buffer[:] with trace_call( trace_name, @@ -189,17 +199,25 @@ def _restart_on_unavailable( transaction_selector = transaction._build_transaction_selector_pb() attempt += 1 request.transaction = transaction_selector + ( + call_metadata, + current_request_id, + ) = request_id_manager.metadata_and_request_id( + nth_request, + attempt, + metadata, + span, + ) iterator = method( request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), + metadata=call_metadata, ) continue + except Exception as exc: + # Augment any other exception with the request ID + raise _augment_error_with_request_id(exc, current_request_id) + if len(item_buffer) == 0: break @@ -961,17 +979,19 @@ def wrapped_method(): begin_transaction_request = BeginTransactionRequest( **begin_request_kwargs ) + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + attempt.increment(), + metadata, + span, + ) begin_transaction_method = functools.partial( api.begin_transaction, request=begin_transaction_request, - metadata=database.metadata_with_request_id( - nth_request, - attempt.increment(), - metadata, - span, - ), + metadata=call_metadata, ) - return begin_transaction_method() + with error_augmenter: + return begin_transaction_method() def before_next_retry(nth_retry, delay_in_seconds): add_span_event( diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index de8b421840..413ac0af1f 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -185,18 +185,20 @@ def rollback(self) -> None: def wrapped_method(*args, **kwargs): attempt.increment() + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + attempt.value, + metadata, + span, + ) rollback_method = functools.partial( api.rollback, session=session.name, transaction_id=self._transaction_id, - metadata=database.metadata_with_request_id( - nth_request, - attempt.value, - metadata, - span, - ), + metadata=call_metadata, ) - return rollback_method(*args, **kwargs) + with error_augmenter: + return rollback_method(*args, **kwargs) _retry( wrapped_method, @@ -298,17 +300,19 @@ def wrapped_method(*args, **kwargs): if is_multiplexed and self._precommit_token is not None: commit_request_args["precommit_token"] = self._precommit_token + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + attempt.value, + metadata, + span, + ) commit_method = functools.partial( api.commit, request=CommitRequest(**commit_request_args), - metadata=database.metadata_with_request_id( - nth_request, - attempt.value, - metadata, - span, - ), + metadata=call_metadata, ) - return commit_method(*args, **kwargs) + with error_augmenter: + return commit_method(*args, **kwargs) commit_retry_event_name = "Transaction Commit Attempt Failed. Retrying" @@ -335,18 +339,20 @@ def before_next_retry(nth_retry, delay_in_seconds): if commit_response_pb._pb.HasField("precommit_token"): add_span_event(span, commit_retry_event_name) nth_request = database._next_nth_request - commit_response_pb = api.commit( - request=CommitRequest( - precommit_token=commit_response_pb.precommit_token, - **common_commit_request_args, - ), - metadata=database.metadata_with_request_id( - nth_request, - 1, - metadata, - span, - ), + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, + 1, + metadata, + span, ) + with error_augmenter: + commit_response_pb = api.commit( + request=CommitRequest( + precommit_token=commit_response_pb.precommit_token, + **common_commit_request_args, + ), + metadata=call_metadata, + ) add_span_event(span, "Commit Done") @@ -510,16 +516,18 @@ def execute_update( def wrapped_method(*args, **kwargs): attempt.increment() + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, attempt.value, metadata + ) execute_sql_method = functools.partial( api.execute_sql, request=execute_sql_request, - metadata=database.metadata_with_request_id( - nth_request, attempt.value, metadata - ), + metadata=call_metadata, retry=retry, timeout=timeout, ) - return execute_sql_method(*args, **kwargs) + with error_augmenter: + return execute_sql_method(*args, **kwargs) result_set_pb: ResultSet = self._execute_request( wrapped_method, @@ -658,16 +666,18 @@ def batch_update( def wrapped_method(*args, **kwargs): attempt.increment() + call_metadata, error_augmenter = database.with_error_augmentation( + nth_request, attempt.value, metadata + ) execute_batch_dml_method = functools.partial( api.execute_batch_dml, request=execute_batch_dml_request, - metadata=database.metadata_with_request_id( - nth_request, attempt.value, metadata - ), + metadata=call_metadata, retry=retry, timeout=timeout, ) - return execute_batch_dml_method(*args, **kwargs) + with error_augmenter: + return execute_batch_dml_method(*args, **kwargs) response_pb: ExecuteBatchDmlResponse = self._execute_request( wrapped_method, diff --git a/noxfile.py b/noxfile.py index e85fba3c54..2cd172c587 100644 --- a/noxfile.py +++ b/noxfile.py @@ -558,6 +558,7 @@ def prerelease_deps(session, protobuf_implementation, database_dialect): # dependency of google-auth "cffi", "cryptography", + "cachetools", ] for dep in prerel_deps: diff --git a/tests/mockserver_tests/test_aborted_transaction.py b/tests/mockserver_tests/test_aborted_transaction.py index a1f9f1ba1e..bd01ba8f4a 100644 --- a/tests/mockserver_tests/test_aborted_transaction.py +++ b/tests/mockserver_tests/test_aborted_transaction.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import random - from google.cloud.spanner_v1 import ( BeginTransactionRequest, CommitRequest, @@ -22,6 +20,7 @@ ) from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer from google.cloud.spanner_v1.transaction import Transaction +from google.cloud.spanner_v1.exceptions import SpannerError from tests.mockserver_tests.mock_server_test_base import ( MockServerTestBase, add_error, @@ -33,8 +32,23 @@ from test_utils import retry from google.cloud.spanner_v1.database_sessions_manager import TransactionType + +def _is_aborted_error(exc): + """Check if exception is Aborted or SpannerError wrapping Aborted.""" + if isinstance(exc, exceptions.Aborted): + return True + if isinstance(exc, SpannerError) and isinstance(exc._error, exceptions.Aborted): + return True + return False + + +# Retry on both Aborted and SpannerError (which wraps Aborted with request ID) retry_maybe_aborted_txn = retry.RetryErrors( - exceptions.Aborted, max_tries=5, delay=0, backoff=1 + (exceptions.Aborted, SpannerError), + error_predicate=_is_aborted_error, + max_tries=5, + delay=0, + backoff=1, ) @@ -119,17 +133,21 @@ def test_batch_commit_aborted(self): TransactionType.READ_WRITE, ) - @retry_maybe_aborted_txn def test_retry_helper(self): - # Randomly add an Aborted error for the Commit method on the mock server. - if random.random() < 0.5: - add_error(SpannerServicer.Commit.__name__, aborted_status()) - session = self.database.session() - session.create() - transaction = session.transaction() - transaction.begin() - transaction.insert("my_table", ["col1, col2"], [{"col1": 1, "col2": "One"}]) - transaction.commit() + # Add an Aborted error for the Commit method on the mock server. + # The error is popped after the first use, so the retry will succeed. + add_error(SpannerServicer.Commit.__name__, aborted_status()) + + @retry_maybe_aborted_txn + def do_commit(): + session = self.database.session() + session.create() + transaction = session.transaction() + transaction.begin() + transaction.insert("my_table", ["col1, col2"], [{"col1": 1, "col2": "One"}]) + transaction.commit() + + do_commit() def _insert_mutations(transaction: Transaction): diff --git a/tests/mockserver_tests/test_dbapi_isolation_level.py b/tests/mockserver_tests/test_dbapi_isolation_level.py index 679740969a..6fb9c5fc74 100644 --- a/tests/mockserver_tests/test_dbapi_isolation_level.py +++ b/tests/mockserver_tests/test_dbapi_isolation_level.py @@ -18,6 +18,7 @@ BeginTransactionRequest, TransactionOptions, ) +from google.cloud.spanner_v1.exceptions import SpannerError from tests.mockserver_tests.mock_server_test_base import ( MockServerTestBase, add_update_count, @@ -146,5 +147,8 @@ def test_begin_isolation_level(self): def test_begin_invalid_isolation_level(self): connection = Connection(self.instance, self.database) with connection.cursor() as cursor: - with self.assertRaises(Unknown): + # The Unknown exception is now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: cursor.execute("begin isolation level does_not_exist") + # Verify that the SpannerError wraps an Unknown exception + self.assertIsInstance(context.exception._error, Unknown) diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index e8297030eb..76ae692bb0 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -41,6 +41,8 @@ from google.cloud.spanner_v1._helpers import ( AtomicCounter, _metadata_with_request_id, + _augment_errors_with_request_id, + _metadata_with_request_id_and_req_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID @@ -205,6 +207,8 @@ def test_commit_already_committed(self): return_value="global", ) def test_commit_grpc_error(self, mock_region): + from google.cloud.spanner_v1.exceptions import SpannerError + keys = [[0], [1], [2]] keyset = KeySet(keys=keys) database = _Database() @@ -213,9 +217,13 @@ def test_commit_grpc_error(self, mock_region): batch = self._make_one(session) batch.delete(TABLE_NAME, keyset=keyset) - with self.assertRaises(Unknown): + # The exception is wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: batch.commit() + # Verify that the SpannerError wraps an Unknown exception + self.assertIsInstance(context.exception._error, Unknown) + req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" self.assertSpanAttributes( "CloudSpanner.Batch.commit", @@ -281,6 +289,8 @@ def test_commit_ok(self, mock_region): def test_aborted_exception_on_commit_with_retries(self): # Test case to verify that an Aborted exception is raised when # batch.commit() is called and the transaction is aborted internally. + # The exception is wrapped in SpannerError with request ID. + from google.cloud.spanner_v1.exceptions import SpannerError database = _Database() # Setup the spanner API which throws Aborted exception when calling commit API. @@ -294,12 +304,13 @@ def test_aborted_exception_on_commit_with_retries(self): batch = self._make_one(session) batch.insert(TABLE_NAME, COLUMNS, VALUES) - # Assertion: Ensure that calling batch.commit() raises the Aborted exception - with self.assertRaises(Aborted) as context: + # Assertion: Ensure that calling batch.commit() raises SpannerError wrapping Aborted + with self.assertRaises(SpannerError) as context: batch.commit(timeout_secs=0.1, default_retry_delay=0) - # Verify additional details about the exception - self.assertEqual(str(context.exception), "409 Transaction was aborted") + # Verify that the SpannerError wraps an Aborted exception and includes request ID + self.assertIn("409 Transaction was aborted", str(context.exception)) + self.assertIn("request_id", str(context.exception)) self.assertGreater( api.commit.call_count, 1, "commit should be called more than once" ) @@ -821,6 +832,19 @@ def metadata_with_request_id( span, ) + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + @property def _channel_id(self): return 1 diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 92001fb52c..09e68da2d8 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -34,6 +34,8 @@ from google.cloud.spanner_v1._helpers import ( AtomicCounter, _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID from google.cloud.spanner_v1.session import Session @@ -2255,6 +2257,7 @@ def test_context_mgr_w_aborted_commit_status(self): from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.batch import Batch + from google.cloud.spanner_v1.exceptions import SpannerError database = _Database(self.DATABASE_NAME) database.log_commit_stats = True @@ -2265,12 +2268,16 @@ def test_context_mgr_w_aborted_commit_status(self): pool.put(session) checkout = self._make_one(database, timeout_secs=0.1, default_retry_delay=0) - with self.assertRaises(Aborted): + # Exception is wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: with checkout as batch: self.assertIsNone(pool._session) self.assertIsInstance(batch, Batch) self.assertIs(batch._session, session) + # Verify that the SpannerError wraps an Aborted exception + self.assertIsInstance(context.exception._error, Aborted) + self.assertIs(pool._session, session) expected_txn_options = TransactionOptions(read_write={}) @@ -3635,6 +3642,19 @@ def metadata_with_request_id( def _channel_id(self): return 1 + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + class _Pool(object): _bound = None diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index c6156b5e8c..34fa2dc5f3 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -21,6 +21,7 @@ from google.api_core.exceptions import BadRequest, FailedPrecondition from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager from google.cloud.spanner_v1.database_sessions_manager import TransactionType +from google.cloud.spanner_v1.exceptions import SpannerError from tests._builders import build_database @@ -208,16 +209,22 @@ def test_exception_bad_request(self): api = manager._database.spanner_api api.create_session.side_effect = BadRequest("") - with self.assertRaises(BadRequest): + # Exception is wrapped with request ID + with self.assertRaises(SpannerError) as cm: manager.get_session(TransactionType.READ_ONLY) + # Verify the wrapped exception contains the original error type + self.assertIsInstance(cm.exception._error, BadRequest) def test_exception_failed_precondition(self): manager = self._manager api = manager._database.spanner_api api.create_session.side_effect = FailedPrecondition("") - with self.assertRaises(FailedPrecondition): + # Exception is wrapped with request ID + with self.assertRaises(SpannerError) as cm: manager.get_session(TransactionType.READ_ONLY) + # Verify the wrapped exception contains the original error type + self.assertIsInstance(cm.exception._error, FailedPrecondition) def test__use_multiplexed_read_only(self): transaction_type = TransactionType.READ_ONLY diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py new file mode 100644 index 0000000000..81bfbbc481 --- /dev/null +++ b/tests/unit/test_exceptions.py @@ -0,0 +1,113 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Spanner exception handling with request IDs.""" + +import unittest + +from google.api_core.exceptions import GoogleAPICallError +from google.api_core.exceptions import Aborted +from google.cloud.spanner_v1.exceptions import SpannerError, wrap_with_request_id + + +class TestSpannerException(unittest.TestCase): + """Test Spanner exception class and error wrapping.""" + + def test_wrap_with_request_id_with_google_api_error(self): + """Test wrapping GoogleAPICallError with request ID.""" + error = Aborted("Transaction aborted") + request_id = "1.12345.1.0.1.1" + + wrapped_error = wrap_with_request_id(error, request_id) + + self.assertIsInstance(wrapped_error, SpannerError) + self.assertEqual(wrapped_error.request_id, request_id) + self.assertIn(request_id, str(wrapped_error)) + self.assertIn("Transaction aborted", str(wrapped_error)) + + def test_wrap_with_request_id_without_request_id(self): + """Test wrapping GoogleAPICallError without request ID.""" + error = Aborted("Transaction aborted") + + wrapped_error = wrap_with_request_id(error) + + self.assertIs(wrapped_error, error) + self.assertNotIsInstance(wrapped_error, SpannerError) + + def test_wrap_with_request_id_with_non_google_api_error(self): + """Test wrapping non-GoogleAPICallError with request ID.""" + error = Exception("Some other error") + request_id = "1.12345.1.0.1.1" + + wrapped_error = wrap_with_request_id(error, request_id) + + self.assertIs(wrapped_error, error) + self.assertNotIsInstance(wrapped_error, SpannerError) + + def test_spanner_error_str_includes_request_id(self): + """Test that SpannerError string representation includes request ID.""" + error = Aborted("Transaction aborted") + request_id = "1.12345.1.0.1.1" + + wrapped_error = SpannerError(error, request_id) + + error_str = str(wrapped_error) + self.assertIn(request_id, error_str) + self.assertIn("Transaction aborted", error_str) + + def test_spanner_error_str_without_request_id(self): + """Test SpannerError string representation without request ID.""" + error = Aborted("Transaction aborted") + + wrapped_error = SpannerError(error, "") + + error_str = str(wrapped_error) + self.assertNotIn("request_id", error_str) + self.assertIn("Transaction aborted", error_str) + + def test_spanner_error_repr(self): + """Test SpannerError repr.""" + error = Aborted("Transaction aborted") + request_id = "1.12345.1.0.1.1" + + wrapped_error = SpannerError(error, request_id) + + repr_str = repr(wrapped_error) + self.assertIn("SpannerError", repr_str) + self.assertIn(request_id, repr_str) + + def test_spanner_error_preserves_original_error_properties(self): + """Test that SpannerError preserves original error properties.""" + error = GoogleAPICallError("Error message") + request_id = "1.12345.1.0.1.1" + + wrapped_error = SpannerError(error, request_id) + + self.assertEqual(wrapped_error.message, "Error message") + self.assertEqual(wrapped_error.request_id, request_id) + + def test_spanner_error_getattr_delegates_to_original(self): + """Test that SpannerError delegates attributes to original error.""" + error = Aborted("Transaction aborted") + request_id = "1.12345.1.0.1.1" + + wrapped_error = SpannerError(error, request_id) + + self.assertEqual(wrapped_error.request_id, request_id) + self.assertEqual(wrapped_error.message, "Transaction aborted") + self.assertEqual(wrapped_error.response, None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index ec03e4350b..a44bb8c5a1 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -21,6 +21,8 @@ import mock from google.cloud.spanner_v1._helpers import ( _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, AtomicCounter, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID @@ -1450,6 +1452,19 @@ def metadata_with_request_id( def _channel_id(self): return 1 + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + class _Queue(object): _size = 1 diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8026c50c24..62a6fd0387 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -53,6 +53,7 @@ from google.protobuf.duration_pb2 import Duration from google.rpc.error_details_pb2 import RetryInfo from google.api_core.exceptions import Unknown, Aborted, NotFound, Cancelled +from google.cloud.spanner_v1.exceptions import SpannerError from google.protobuf.struct_pb2 import Struct, Value from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1 import DefaultTransactionOptions @@ -92,7 +93,11 @@ def inject_into_mock_database(mockdb): def metadata_with_request_id( nth_request, nth_attempt, prior_metadata=[], span=None ): - nth_req = nth_request.fget(mockdb) + # Handle both cases: nth_request as an integer or as a property descriptor + if isinstance(nth_request, int): + nth_req = nth_request + else: + nth_req = nth_request.fget(mockdb) return _metadata_with_request_id( nth_client_id, channel_id, @@ -104,11 +109,45 @@ def metadata_with_request_id( setattr(mockdb, "metadata_with_request_id", metadata_with_request_id) - @property - def _next_nth_request(self): - return self._nth_request.increment() + # Create a property-like object using type() to make it work with mock + type(mockdb)._next_nth_request = property( + lambda self: self._nth_request.increment() + ) + + # Use a closure to capture nth_client_id and channel_id + def make_with_error_augmentation(db_nth_client_id, db_channel_id): + def with_error_augmentation( + nth_request, nth_attempt, prior_metadata=[], span=None + ): + """Context manager for gRPC calls with error augmentation.""" + from google.cloud.spanner_v1._helpers import ( + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, + ) + + if span is None: + from google.cloud.spanner_v1._opentelemetry_tracing import ( + get_current_span, + ) + + span = get_current_span() + + metadata, request_id = _metadata_with_request_id_and_req_id( + db_nth_client_id, + db_channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) - setattr(mockdb, "_next_nth_request", _next_nth_request) + return metadata, _augment_errors_with_request_id(request_id) + + return with_error_augmentation + + mockdb.with_error_augmentation = make_with_error_augmentation( + nth_client_id, channel_id + ) return mockdb @@ -443,8 +482,11 @@ def test_create_error(self, mock_region): database.spanner_api = gax_api session = self._make_one(database) - with self.assertRaises(Unknown): + # Exception is wrapped with request ID + with self.assertRaises(SpannerError) as cm: session.create() + # Verify the wrapped exception contains the original error type + self.assertIsInstance(cm.exception._error, Unknown) req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" self.assertSpanAttributes( @@ -547,8 +589,11 @@ def test_exists_error(self, mock_region): session = self._make_one(database) session._session_id = self.SESSION_ID - with self.assertRaises(Unknown): + # Exception is wrapped with request ID + with self.assertRaises(SpannerError) as cm: session.exists() + # Verify the wrapped exception contains the original error type + self.assertIsInstance(cm.exception._error, Unknown) req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" gax_api.get_session.assert_called_once_with( @@ -625,7 +670,7 @@ def test_ping_miss(self, mock_region): session = self._make_one(database) session._session_id = self.SESSION_ID - with self.assertRaises(NotFound): + with self.assertRaises(SpannerError): session.ping() request = ExecuteSqlRequest( @@ -663,7 +708,7 @@ def test_ping_error(self, mock_region): session = self._make_one(database) session._session_id = self.SESSION_ID - with self.assertRaises(Unknown): + with self.assertRaises(SpannerError): session.ping() request = ExecuteSqlRequest( @@ -743,7 +788,7 @@ def test_delete_miss(self, mock_region): session = self._make_one(database) session._session_id = self.SESSION_ID - with self.assertRaises(NotFound): + with self.assertRaises(SpannerError): session.delete() req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" @@ -783,7 +828,7 @@ def test_delete_error(self, mock_region): session = self._make_one(database) session._session_id = self.SESSION_ID - with self.assertRaises(Unknown): + with self.assertRaises(SpannerError): session.delete() req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" @@ -1269,6 +1314,8 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_commit_error(self): + from google.cloud.spanner_v1.exceptions import SpannerError + TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] VALUES = [ @@ -1292,8 +1339,10 @@ def unit_of_work(txn, *args, **kw): called_with.append((txn, args, kw)) txn.insert(TABLE_NAME, COLUMNS, VALUES) - with self.assertRaises(Unknown): + # Errors are now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: session.run_in_transaction(unit_of_work) + self.assertIsInstance(context.exception._error, Unknown) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] @@ -1628,6 +1677,8 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_abort_w_retry_metadata_deadline(self): + from google.cloud.spanner_v1.exceptions import SpannerError + RETRY_SECONDS = 1 RETRY_NANOS = 3456 transaction_pb = TransactionPB(id=TRANSACTION_ID) @@ -1661,8 +1712,10 @@ def _time(_results=[1, 1.5]): with mock.patch("time.time", _time): with mock.patch("time.sleep") as sleep_mock: - with self.assertRaises(Aborted): + # Errors are now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: session.run_in_transaction(unit_of_work, "abc", timeout_secs=1) + self.assertIsInstance(context.exception._error, Aborted) sleep_mock.assert_not_called() @@ -1706,6 +1759,8 @@ def _time(_results=[1, 1.5]): ) def test_run_in_transaction_w_timeout(self): + from google.cloud.spanner_v1.exceptions import SpannerError + transaction_pb = TransactionPB(id=TRANSACTION_ID) aborted = _make_rpc_error(Aborted, trailing_metadata=[]) gax_api = self._make_spanner_api() @@ -1729,8 +1784,10 @@ def _time(_results=[1, 2, 4, 8]): with mock.patch("time.time", _time), mock.patch( "google.cloud.spanner_v1._helpers.random.random", return_value=0 ), mock.patch("time.sleep") as sleep_mock: - with self.assertRaises(Aborted): + # Errors are now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: session.run_in_transaction(unit_of_work, timeout_secs=8) + self.assertIsInstance(context.exception._error, Aborted) # unpacking call args into list call_args = [call_[0][0] for call_ in sleep_mock.call_args_list] @@ -1911,6 +1968,8 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_commit_stats_error(self): + from google.cloud.spanner_v1.exceptions import SpannerError + transaction_pb = TransactionPB(id=TRANSACTION_ID) gax_api = self._make_spanner_api() gax_api.begin_transaction.return_value = transaction_pb @@ -1928,8 +1987,10 @@ def unit_of_work(txn, *args, **kw): txn.insert(TABLE_NAME, COLUMNS, VALUES) return 42 - with self.assertRaises(Unknown): + # Errors are now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: session.run_in_transaction(unit_of_work, "abc", some_arg="def") + self.assertIsInstance(context.exception._error, Unknown) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 974cc8e75e..c65bb4b0d7 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -44,6 +44,8 @@ ) from google.cloud.spanner_v1._helpers import ( _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, AtomicCounter, ) from google.cloud.spanner_v1.param_types import INT64 @@ -282,6 +284,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): from google.api_core.exceptions import InternalServerError + from google.cloud.spanner_v1.exceptions import SpannerError ITEMS = ( self._make_item(0), @@ -297,8 +300,10 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): session = _Session(database) derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) - with self.assertRaises(InternalServerError): + # Errors are now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: list(resumable) + self.assertIsInstance(context.exception._error, InternalServerError) restart.assert_called_once_with( request=request, metadata=[ @@ -356,6 +361,7 @@ def test_iteration_w_raw_raising_retryable_internal_error(self): def test_iteration_w_raw_raising_non_retryable_internal_error(self): from google.api_core.exceptions import InternalServerError + from google.cloud.spanner_v1.exceptions import SpannerError FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN)) SECOND = (self._make_item(2),) # discarded after 503 @@ -371,8 +377,10 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): session = _Session(database) derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) - with self.assertRaises(InternalServerError): + # Errors are now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: list(resumable) + self.assertIsInstance(context.exception._error, InternalServerError) restart.assert_called_once_with( request=request, metadata=[ @@ -532,6 +540,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): from google.api_core.exceptions import InternalServerError + from google.cloud.spanner_v1.exceptions import SpannerError FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN)) SECOND = (self._make_item(2), self._make_item(3)) @@ -546,8 +555,10 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): session = _Session(database) derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) - with self.assertRaises(InternalServerError): + # Errors are now wrapped in SpannerError with request ID + with self.assertRaises(SpannerError) as context: list(resumable) + self.assertIsInstance(context.exception._error, InternalServerError) restart.assert_called_once_with( request=request, metadata=[ @@ -2168,6 +2179,31 @@ def metadata_with_request_id( span, ) + def metadata_and_request_id( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + return _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + @property def _channel_id(self): return 1 diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index d1de23d2d0..ecd7d4fd86 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -42,6 +42,8 @@ _make_value_pb, _merge_query_options, _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID import mock @@ -1319,10 +1321,35 @@ def metadata_with_request_id( span, ) + def metadata_and_request_id( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + return _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + @property def _channel_id(self): return 1 + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + class _Session(object): _transaction = None diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 510251656e..405521509f 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -35,6 +35,8 @@ from google.cloud.spanner_v1._helpers import ( AtomicCounter, _metadata_with_request_id, + _metadata_with_request_id_and_req_id, + _augment_errors_with_request_id, ) from google.cloud.spanner_v1.batch import _make_write_pb from google.cloud.spanner_v1.database import Database @@ -1420,6 +1422,19 @@ def metadata_with_request_id( span, ) + def with_error_augmentation( + self, nth_request, nth_attempt, prior_metadata=[], span=None + ): + metadata, request_id = _metadata_with_request_id_and_req_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + span, + ) + return metadata, _augment_errors_with_request_id(request_id) + @property def _channel_id(self): return 1