2727import com .google .cloud .NoCredentials ;
2828import com .google .cloud .spanner .MockSpannerServiceImpl .SimulatedExecutionTime ;
2929import com .google .cloud .spanner .connection .AbstractMockServerTest ;
30- import com .google .common .collect .ImmutableSet ;
3130import com .google .spanner .v1 .BatchCreateSessionsRequest ;
3231import com .google .spanner .v1 .BeginTransactionRequest ;
3332import com .google .spanner .v1 .ExecuteSqlRequest ;
6463@ RunWith (JUnit4 .class )
6564public class RetryOnDifferentGrpcChannelMockServerTest extends AbstractMockServerTest {
6665 private static final Map <String , Set <InetSocketAddress >> SERVER_ADDRESSES = new HashMap <>();
66+ private static final Map <String , Set <Long >> CHANNEL_HINTS = new HashMap <>();
6767
6868 @ BeforeClass
6969 public static void startStaticServer () throws IOException {
@@ -79,6 +79,7 @@ public static void removeSystemProperty() {
7979 @ After
8080 public void clearRequests () {
8181 SERVER_ADDRESSES .clear ();
82+ CHANNEL_HINTS .clear ();
8283 mockSpanner .clearRequests ();
8384 mockSpanner .removeAllExecutionTimes ();
8485 }
@@ -91,6 +92,7 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(
9192 Metadata metadata ,
9293 ServerCallHandler <ReqT , RespT > serverCallHandler ) {
9394 Attributes attributes = serverCall .getAttributes ();
95+ String methodName = serverCall .getMethodDescriptor ().getFullMethodName ();
9496 //noinspection unchecked,deprecation
9597 Attributes .Key <InetSocketAddress > key =
9698 (Attributes .Key <InetSocketAddress >)
@@ -102,11 +104,26 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(
102104 InetSocketAddress address = attributes .get (key );
103105 synchronized (SERVER_ADDRESSES ) {
104106 Set <InetSocketAddress > addresses =
105- SERVER_ADDRESSES .getOrDefault (
106- serverCall .getMethodDescriptor ().getFullMethodName (), new HashSet <>());
107+ SERVER_ADDRESSES .getOrDefault (methodName , new HashSet <>());
107108 addresses .add (address );
108- SERVER_ADDRESSES .putIfAbsent (
109- serverCall .getMethodDescriptor ().getFullMethodName (), addresses );
109+ SERVER_ADDRESSES .putIfAbsent (methodName , addresses );
110+ }
111+ }
112+ String requestId = metadata .get (XGoogSpannerRequestId .REQUEST_HEADER_KEY );
113+ if (requestId != null ) {
114+ // REQUEST_ID format: version.randProcessId.nthClientId.nthChannelId.nthRequest.attempt
115+ String [] parts = requestId .split ("\\ ." );
116+ if (parts .length >= 6 ) {
117+ try {
118+ long channelHint = Long .parseLong (parts [3 ]);
119+ synchronized (CHANNEL_HINTS ) {
120+ Set <Long > hints = CHANNEL_HINTS .getOrDefault (methodName , new HashSet <>());
121+ hints .add (channelHint );
122+ CHANNEL_HINTS .putIfAbsent (methodName , hints );
123+ }
124+ } catch (NumberFormatException ignore ) {
125+ // Ignore malformed header values in tests.
126+ }
110127 }
111128 }
112129 return serverCallHandler .startCall (serverCall , metadata );
@@ -157,8 +174,8 @@ public void testReadWriteTransaction_retriesOnNewChannel() {
157174 assertNotEquals (requests .get (0 ).getSession (), requests .get (1 ).getSession ());
158175 assertEquals (
159176 2 ,
160- SERVER_ADDRESSES
161- .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , ImmutableSet . of ())
177+ CHANNEL_HINTS
178+ .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , new HashSet <> ())
162179 .size ());
163180 }
164181
@@ -201,8 +218,8 @@ public void testReadWriteTransaction_stopsRetrying() {
201218 assertEquals (numChannels , sessions .size ());
202219 assertEquals (
203220 numChannels ,
204- SERVER_ADDRESSES
205- .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , ImmutableSet . of ())
221+ CHANNEL_HINTS
222+ .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , new HashSet <> ())
206223 .size ());
207224 }
208225 }
@@ -275,8 +292,8 @@ public void testDenyListedChannelIsCleared() {
275292 assertEquals (numChannels + 1 , sessions .size ());
276293 assertEquals (
277294 numChannels ,
278- SERVER_ADDRESSES
279- .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , ImmutableSet . of ())
295+ CHANNEL_HINTS
296+ .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , new HashSet <> ())
280297 .size ());
281298 assertEquals (numChannels , mockSpanner .countRequestsOfType (BatchCreateSessionsRequest .class ));
282299 }
@@ -303,11 +320,11 @@ public void testSingleUseQuery_retriesOnNewChannel() {
303320 List <ExecuteSqlRequest > requests = mockSpanner .getRequestsOfType (ExecuteSqlRequest .class );
304321 // The requests use the same multiplexed session.
305322 assertEquals (requests .get (0 ).getSession (), requests .get (1 ).getSession ());
306- // The requests use two different gRPC channels .
323+ // The requests use two different channel hints (which may map to same physical channel) .
307324 assertEquals (
308325 2 ,
309- SERVER_ADDRESSES
310- .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , ImmutableSet . of ())
326+ CHANNEL_HINTS
327+ .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , new HashSet <> ())
311328 .size ());
312329 }
313330
@@ -327,19 +344,19 @@ public void testSingleUseQuery_stopsRetrying() {
327344 assertEquals (ErrorCode .DEADLINE_EXCEEDED , exception .getErrorCode ());
328345 }
329346 int numChannels = spanner .getOptions ().getNumChannels ();
330- assertEquals (numChannels , mockSpanner .countRequestsOfType (ExecuteSqlRequest .class ));
331347 List <ExecuteSqlRequest > requests = mockSpanner .getRequestsOfType (ExecuteSqlRequest .class );
332348 // The requests use the same multiplexed session.
333349 String session = requests .get (0 ).getSession ();
334350 for (ExecuteSqlRequest request : requests ) {
335351 assertEquals (session , request .getSession ());
336352 }
337- // The requests use all gRPC channels.
338- assertEquals (
339- numChannels ,
340- SERVER_ADDRESSES
341- .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , ImmutableSet .of ())
342- .size ());
353+ // Each attempt, including retries, must use a distinct channel hint.
354+ int totalRequests = mockSpanner .countRequestsOfType (ExecuteSqlRequest .class );
355+ int distinctHints =
356+ CHANNEL_HINTS
357+ .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , new HashSet <>())
358+ .size ();
359+ assertEquals (totalRequests , distinctHints );
343360 }
344361 }
345362
0 commit comments