@@ -4,8 +4,10 @@ import (
44 "context"
55 "fmt"
66 "sort"
7+ "sync/atomic"
78
89 "google.golang.org/grpc"
10+ grpcCodes "google.golang.org/grpc/codes"
911
1012 "github.com/ydb-platform/ydb-go-sdk/v3/config"
1113 balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
@@ -19,7 +21,6 @@ import (
1921 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
2022 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
2123 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
22- "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
2324 "github.com/ydb-platform/ydb-go-sdk/v3/retry"
2425 "github.com/ydb-platform/ydb-go-sdk/v3/trace"
2526)
@@ -40,33 +41,13 @@ type Balancer struct {
4041 discoveryRepeater repeater.Repeater
4142 localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
4243
43- mu xsync.RWMutex
44- connectionsState * connectionsState [conn.Conn ]
44+ connections atomic.Pointer [connections [conn.Conn ]]
4545
4646 closed chan struct {}
4747
4848 onApplyDiscoveredEndpoints []func (ctx context.Context , endpoints []endpoint.Info )
4949}
5050
51- func (b * Balancer ) HasNode (id uint32 ) bool {
52- if b .config .SingleConn {
53- return true
54- }
55- b .mu .RLock ()
56- defer b .mu .RUnlock ()
57- if _ , has := b .connectionsState .connByNodeID [id ]; has {
58- return true
59- }
60-
61- return false
62- }
63-
64- func (b * Balancer ) OnUpdate (onApplyDiscoveredEndpoints func (ctx context.Context , endpoints []endpoint.Info )) {
65- b .mu .WithLock (func () {
66- b .onApplyDiscoveredEndpoints = append (b .onApplyDiscoveredEndpoints , onApplyDiscoveredEndpoints )
67- })
68- }
69-
7051func (b * Balancer ) clusterDiscovery (ctx context.Context ) (err error ) {
7152 return retry .Retry (
7253 repeater .WithEvent (ctx , repeater .EventInit ),
@@ -135,37 +116,37 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
135116 return nil
136117}
137118
138- func endpointsDiff (newestEndpoints []endpoint. Endpoint , previousConns []conn. Info ) (
119+ func endpointsDiff (newestEndpoints []trace. EndpointInfo , previousEndpoints []trace. EndpointInfo ) (
139120 nodes []trace.EndpointInfo ,
140121 added []trace.EndpointInfo ,
141122 dropped []trace.EndpointInfo ,
142123) {
143124 nodes = make ([]trace.EndpointInfo , 0 , len (newestEndpoints ))
144- added = make ([]trace.EndpointInfo , 0 , len (previousConns ))
145- dropped = make ([]trace.EndpointInfo , 0 , len (previousConns ))
125+ added = make ([]trace.EndpointInfo , 0 , len (previousEndpoints ))
126+ dropped = make ([]trace.EndpointInfo , 0 , len (previousEndpoints ))
146127 var (
147128 newestMap = make (map [string ]struct {}, len (newestEndpoints ))
148- previousMap = make (map [string ]struct {}, len (previousConns ))
129+ previousMap = make (map [string ]struct {}, len (previousEndpoints ))
149130 )
150131 sort .Slice (newestEndpoints , func (i , j int ) bool {
151132 return newestEndpoints [i ].Address () < newestEndpoints [j ].Address ()
152133 })
153- sort .Slice (previousConns , func (i , j int ) bool {
154- return previousConns [i ].Endpoint (). Address () < previousConns [j ]. Endpoint () .Address ()
134+ sort .Slice (previousEndpoints , func (i , j int ) bool {
135+ return previousEndpoints [i ].Address () < previousEndpoints [j ].Address ()
155136 })
156- for _ , e := range previousConns {
157- previousMap [e .Endpoint (). Address ()] = struct {}{}
137+ for _ , e := range previousEndpoints {
138+ previousMap [e .Address ()] = struct {}{}
158139 }
159- for _ , e := range newestEndpoints {
160- nodes = append (nodes , e . Copy () )
161- newestMap [e .Address ()] = struct {}{}
162- if _ , has := previousMap [e .Address ()]; ! has {
163- added = append (added , e . Copy () )
140+ for _ , info := range newestEndpoints {
141+ nodes = append (nodes , info )
142+ newestMap [info .Address ()] = struct {}{}
143+ if _ , has := previousMap [info .Address ()]; ! has {
144+ added = append (added , info )
164145 }
165146 }
166- for _ , c := range previousConns {
167- if _ , has := newestMap [c . Endpoint () .Address ()]; ! has {
168- dropped = append (dropped , c . Endpoint (). Copy () )
147+ for _ , info := range previousEndpoints {
148+ if _ , has := newestMap [info .Address ()]; ! has {
149+ dropped = append (dropped , info )
169150 }
170151 }
171152
@@ -180,41 +161,28 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
180161 "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
181162 b .config .DetectLocalDC ,
182163 )
183- previousConns []conn.Info
184164 )
185- defer func () {
186- nodes , added , dropped := endpointsDiff (endpoints , previousConns )
187- onDone (nodes , added , dropped , localDC )
188- }()
189165
190166 connections := endpointsToConnections (b .pool , endpoints )
191167 for _ , c := range connections {
192- if c .State () == conn .Banned {
193- b .pool .Unban (ctx , c )
194- }
195168 c .Endpoint ().Touch ()
196169 }
197170
198171 info := balancerConfig.Info {SelfLocation : localDC }
199- state := newConnectionsState (connections , b .config .Filter , info , b .config .AllowFallback )
200-
201- endpointsInfo := make ([]endpoint.Info , len (endpoints ))
202- for i , e := range endpoints {
203- endpointsInfo [i ] = e
204- }
205-
206- b .mu .WithLock (func () {
207- if b .connectionsState != nil {
208- previousConns = make ([]conn.Info , len (b .connectionsState .all ))
209- for i := range b .connectionsState .all {
210- previousConns [i ] = b .connectionsState .all [i ]
211- }
212- }
213- b .connectionsState = state
214- for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
215- onApplyDiscoveredEndpoints (ctx , endpointsInfo )
172+ newestConnections := newConns (connections , b .config .Filter , info , b .config .AllowFallback )
173+ previousConnections := b .connections .Swap (newestConnections )
174+ defer func () {
175+ if previousConnections != nil {
176+ nodes , added , dropped := endpointsDiff (newestConnections .all .ToTraceEndpointInfo (), previousConnections .all .ToTraceEndpointInfo ())
177+ onDone (nodes , added , dropped , localDC )
178+ } else {
179+ nodes , added , dropped := endpointsDiff (newestConnections .all .ToTraceEndpointInfo (), nil )
180+ onDone (nodes , added , dropped , localDC )
216181 }
217- })
182+ }()
183+ for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
184+ onApplyDiscoveredEndpoints (ctx , newestConnections .all .ToEndpointInfo ())
185+ }
218186}
219187
220188func (b * Balancer ) Close (ctx context.Context ) (err error ) {
@@ -241,6 +209,44 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
241209 return nil
242210}
243211
212+ func (b * Balancer ) markConnAsBad (ctx context.Context , cc conn.Conn , cause error ) {
213+ onDone := trace .DriverOnBalancerMarkConnAsBad (
214+ b .driverConfig .Trace (), & ctx ,
215+ stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).markConnAsBad" ),
216+ cc .Endpoint (), cause ,
217+ )
218+
219+ if ! xerrors .IsTransportError (cause ,
220+ grpcCodes .ResourceExhausted ,
221+ grpcCodes .Unavailable ,
222+ // grpcCodes.OK,
223+ // grpcCodes.Canceled,
224+ // grpcCodes.Unknown,
225+ // grpcCodes.InvalidArgument,
226+ // grpcCodes.DeadlineExceeded,
227+ // grpcCodes.NotFound,
228+ // grpcCodes.AlreadyExists,
229+ // grpcCodes.PermissionDenied,
230+ // grpcCodes.FailedPrecondition,
231+ // grpcCodes.Aborted,
232+ // grpcCodes.OutOfRange,
233+ // grpcCodes.Unimplemented,
234+ // grpcCodes.Internal,
235+ // grpcCodes.DataLoss,
236+ // grpcCodes.Unauthenticated,
237+ ) {
238+ return
239+ }
240+
241+ newestConns , changed := b .connections .Load ().withBadConn (cc )
242+
243+ if changed {
244+ b .connections .Store (newestConns )
245+ }
246+
247+ onDone (newestConns .prefer .ToTraceEndpointInfo (), newestConns .fallback .ToTraceEndpointInfo ())
248+ }
249+
244250func New (
245251 ctx context.Context ,
246252 driverConfig * config.Config ,
@@ -353,10 +359,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
353359 }
354360
355361 defer func () {
356- if err == nil {
357- b .pool .Unban (ctx , cc )
358- } else if xerrors .MustBanConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
359- b .pool .Ban (ctx , cc , err )
362+ if err != nil {
363+ b .markConnAsBad (ctx , cc , err )
360364 }
361365 }()
362366
@@ -383,13 +387,6 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
383387 return nil
384388}
385389
386- func (b * Balancer ) connections () * connectionsState [conn.Conn ] {
387- b .mu .RLock ()
388- defer b .mu .RUnlock ()
389-
390- return b .connectionsState
391- }
392-
393390func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
394391 onDone := trace .DriverOnBalancerChooseEndpoint (
395392 b .driverConfig .Trace (), & ctx ,
@@ -408,17 +405,17 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
408405 }
409406
410407 var (
411- state = b .connections ()
408+ connections = b .connections . Load ()
412409 failedCount int
413410 )
414411
415412 defer func () {
416- if failedCount * 2 > state .PreferredCount () && b .discoveryRepeater != nil {
413+ if failedCount * 2 > connections .PreferredCount () && b .discoveryRepeater != nil {
417414 b .discoveryRepeater .Force ()
418415 }
419416 }()
420417
421- c , failedCount = state . GetConnection (ctx )
418+ c , failedCount = connections . GetConn (ctx )
422419 if c == nil {
423420 return nil , xerrors .WithStackTrace (
424421 fmt .Errorf ("cannot get connection from Balancer after %d attempts: %w" , failedCount , ErrNoEndpoints ),
@@ -429,10 +426,10 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
429426}
430427
431428func endpointsToConnections (p * conn.Pool , endpoints []endpoint.Endpoint ) []conn.Conn {
432- conns := make ([]conn.Conn , 0 , len (endpoints ))
429+ connections := make ([]conn.Conn , 0 , len (endpoints ))
433430 for _ , e := range endpoints {
434- conns = append (conns , p .Get (e ))
431+ connections = append (connections , p .Get (e ))
435432 }
436433
437- return conns
434+ return connections
438435}
0 commit comments