From debceaf3c78350ab37b8cbc6bcf3b3470cb60b8d Mon Sep 17 00:00:00 2001
From: JmPotato <ghzpotato@gmail.com>
Date: Mon, 27 Jan 2025 20:49:59 +0800
Subject: [PATCH] client/router: implement the query region gRPC client (#8939)

ref tikv/pd#8690

Implement the router stream update logic.

Signed-off-by: JmPotato <ghzpotato@gmail.com>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
---
 client/client.go                         |  46 +--
 client/clients/router/client.go          | 434 +++++++++++++++++++++++
 client/clients/router/request.go         | 114 ++++++
 client/go.mod                            |   2 +-
 client/go.sum                            |   4 +-
 client/inner_client.go                   |  10 +
 client/pkg/connectionctx/manager.go      |   3 +-
 tests/integrations/client/client_test.go |  83 ++++-
 8 files changed, 670 insertions(+), 26 deletions(-)
 create mode 100644 client/clients/router/request.go

diff --git a/client/client.go b/client/client.go
index fa0a1473ba7..e5f6442780b 100644
--- a/client/client.go
+++ b/client/client.go
@@ -143,12 +143,11 @@ var _ Client = (*client)(nil)
 
 // serviceModeKeeper is for service mode switching.
 type serviceModeKeeper struct {
-	// RMutex here is for the future usage that there might be multiple goroutines
-	// triggering service mode switching concurrently.
 	sync.RWMutex
 	serviceMode     pdpb.ServiceMode
 	tsoClient       *tso.Cli
 	tsoSvcDiscovery sd.ServiceDiscovery
+	routerClient    *router.Cli
 }
 
 func (k *serviceModeKeeper) close() {
@@ -570,21 +569,16 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e
 	return minTS.Physical, minTS.Logical, nil
 }
 
-func handleRegionResponse(res *pdpb.GetRegionResponse) *router.Region {
-	if res.Region == nil {
-		return nil
-	}
+// EnableRouterClient enables the router client.
+// It is currently only used for testing.
+func (c *client) EnableRouterClient() {
+	c.inner.initRouterClient()
+}
 
-	r := &router.Region{
-		Meta:         res.Region,
-		Leader:       res.Leader,
-		PendingPeers: res.PendingPeers,
-		Buckets:      res.Buckets,
-	}
-	for _, s := range res.DownPeers {
-		r.DownPeers = append(r.DownPeers, s.Peer)
-	}
-	return r
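+// getRouterClient returns the router client if it has been initialized via
+// EnableRouterClient, or nil otherwise, in which case the callers fall back
+// to the unary GetRegion-family RPCs.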
+func (c *client) getRouterClient() *router.Cli {
+	c.inner.RLock()
+	defer c.inner.RUnlock()
+	return c.inner.routerClient
 }
 
 // GetRegionFromMember implements the RPCClient interface.
@@ -623,7 +617,7 @@ func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs
 		errorMsg := fmt.Sprintf("[pd] can't get region info from member URLs: %+v", memberURLs)
 		return nil, errors.WithStack(errors.New(errorMsg))
 	}
-	return handleRegionResponse(resp), nil
+	return router.ConvertToRegion(resp), nil
 }
 
 // GetRegion implements the RPCClient interface.
@@ -637,6 +631,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio
 	ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
 	defer cancel()
 
+	if routerClient := c.getRouterClient(); routerClient != nil {
+		return routerClient.GetRegion(ctx, key, opts...)
+	}
+
 	options := &opt.GetRegionOp{}
 	for _, opt := range opts {
 		opt(options)
@@ -663,7 +661,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio
 	if err = c.respForErr(metrics.CmdFailedDurationGetRegion, start, err, resp.GetHeader()); err != nil {
 		return nil, err
 	}
-	return handleRegionResponse(resp), nil
+	return router.ConvertToRegion(resp), nil
 }
 
 // GetPrevRegion implements the RPCClient interface.
@@ -677,6 +675,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR
 	ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
 	defer cancel()
 
+	if routerClient := c.getRouterClient(); routerClient != nil {
+		return routerClient.GetPrevRegion(ctx, key, opts...)
+	}
+
 	options := &opt.GetRegionOp{}
 	for _, opt := range opts {
 		opt(options)
@@ -703,7 +705,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR
 	if err = c.respForErr(metrics.CmdFailedDurationGetPrevRegion, start, err, resp.GetHeader()); err != nil {
 		return nil, err
 	}
-	return handleRegionResponse(resp), nil
+	return router.ConvertToRegion(resp), nil
 }
 
 // GetRegionByID implements the RPCClient interface.
@@ -717,6 +719,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt
 	ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
 	defer cancel()
 
+	if routerClient := c.getRouterClient(); routerClient != nil {
+		return routerClient.GetRegionByID(ctx, regionID, opts...)
+	}
+
 	options := &opt.GetRegionOp{}
 	for _, opt := range opts {
 		opt(options)
@@ -744,7 +750,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt
 	if err = c.respForErr(metrics.CmdFailedDurationGetRegionByID, start, err, resp.GetHeader()); err != nil {
 		return nil, err
 	}
-	return handleRegionResponse(resp), nil
+	return router.ConvertToRegion(resp), nil
 }
 
 // ScanRegions implements the RPCClient interface.
diff --git a/client/clients/router/client.go b/client/clients/router/client.go
index 48cebfa950e..0038c42dff8 100644
--- a/client/clients/router/client.go
+++ b/client/clients/router/client.go
@@ -18,12 +18,30 @@ import (
 	"context"
 	"encoding/hex"
 	"net/url"
+	"runtime/trace"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/opentracing/opentracing-go"
+	"go.uber.org/zap"
+	"google.golang.org/grpc"
 
 	"github.com/pingcap/kvproto/pkg/metapb"
+	"github.com/pingcap/kvproto/pkg/pdpb"
+	"github.com/pingcap/log"
 
+	"github.com/tikv/pd/client/errs"
 	"github.com/tikv/pd/client/opt"
+	"github.com/tikv/pd/client/pkg/batch"
+	cctx "github.com/tikv/pd/client/pkg/connectionctx"
+	"github.com/tikv/pd/client/pkg/retry"
+	sd "github.com/tikv/pd/client/servicediscovery"
 )
 
+// defaultMaxRouterRequestBatchSize is the default max size of the router request batch.
+const defaultMaxRouterRequestBatchSize = 10000
+
 // Region contains information of a region's meta and its peers.
 type Region struct {
 	Meta         *metapb.Region
@@ -33,6 +51,33 @@ type Region struct {
 	Buckets      *metapb.Buckets
 }
 
+type regionResponse interface {
+	GetRegion() *metapb.Region
+	GetLeader() *metapb.Peer
+	GetDownPeers() []*pdpb.PeerStats
+	GetPendingPeers() []*metapb.Peer
+	GetBuckets() *metapb.Buckets
+}
+
+// ConvertToRegion converts a region response into a Region.
+func ConvertToRegion(res regionResponse) *Region {
+	region := res.GetRegion()
+	if region == nil {
+		return nil
+	}
+
+	r := &Region{
+		Meta:         region,
+		Leader:       res.GetLeader(),
+		PendingPeers: res.GetPendingPeers(),
+		Buckets:      res.GetBuckets(),
+	}
+	for _, s := range res.GetDownPeers() {
+		r.DownPeers = append(r.DownPeers, s.Peer)
+	}
+	return r
+}
+
 // KeyRange defines a range of keys in bytes.
 type KeyRange struct {
 	StartKey []byte
@@ -92,3 +137,392 @@ type Client interface {
 	// The returned regions are flattened: even if multiple key ranges fall within the same region, that region is returned only once.
 	BatchScanRegions(ctx context.Context, keyRanges []KeyRange, limit int, opts ...opt.GetRegionOption) ([]*Region, error)
 }
+
+// Cli is the implementation of the router client.
+type Cli struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+	wg     sync.WaitGroup
+	option *opt.Option
+
+	svcDiscovery sd.ServiceDiscovery
+	// leaderURL is the URL of the router leader.
+	leaderURL atomic.Value
+	// conCtxMgr is used to store the context of the router stream connection(s).
+	conCtxMgr *cctx.Manager[pdpb.PD_QueryRegionClient]
+	// updateConnectionCh is used to proactively trigger a connection update.
+	updateConnectionCh chan struct{}
+	// bo is the backoffer for the router client.
+	bo *retry.Backoffer
+
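+	// reqPool recycles Request objects to reduce allocations, requestCh
+	// buffers the incoming requests, and batchController collects them into
+	// batches of at most defaultMaxRouterRequestBatchSize.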
+	reqPool         *sync.Pool
+	requestCh       chan *Request
+	batchController *batch.Controller[*Request]
+}
+
+// NewClient returns a new router client.
+func NewClient(
+	ctx context.Context,
+	svcDiscovery sd.ServiceDiscovery,
+	option *opt.Option,
+) *Cli {
+	ctx, cancel := context.WithCancel(ctx)
+	c := &Cli{
+		ctx:                ctx,
+		cancel:             cancel,
+		svcDiscovery:       svcDiscovery,
+		option:             option,
+		conCtxMgr:          cctx.NewManager[pdpb.PD_QueryRegionClient](),
+		updateConnectionCh: make(chan struct{}, 1),
+		bo: retry.InitialBackoffer(
+			sd.UpdateMemberBackOffBaseTime,
+			sd.UpdateMemberMaxBackoffTime,
+			sd.UpdateMemberTimeout,
+		),
+		reqPool: &sync.Pool{
+			New: func() any {
+				return &Request{
+					done: make(chan error, 1),
+				}
+			},
+		},
+		requestCh:       make(chan *Request, defaultMaxRouterRequestBatchSize*2),
+		batchController: batch.NewController(defaultMaxRouterRequestBatchSize, requestFinisher(nil), nil),
+	}
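+	// Keep the cached leader URL in sync with service discovery, and refresh
+	// the stream connections whenever the cluster membership changes.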
+	c.leaderURL.Store(svcDiscovery.GetServingURL())
+	c.svcDiscovery.ExecAndAddLeaderSwitchedCallback(c.updateLeaderURL)
+	c.svcDiscovery.AddMembersChangedCallback(c.scheduleUpdateConnection)
+
+	c.wg.Add(2)
+	go c.connectionDaemon()
+	go c.dispatcher()
+
+	return c
+}
+
+func (c *Cli) newRequest(ctx context.Context) *Request {
+	req := c.reqPool.Get().(*Request)
+	req.requestCtx = ctx
+	req.clientCtx = c.ctx
+	req.pool = c.reqPool
+
+	return req
+}
+
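+// requestFinisher builds the callback used to finish a batch of collected
+// requests. The response keys regions by ID: `KeyIdMap` and `PrevKeyIdMap`
+// map the i-th key/prev-key of the batched request to a region ID, so two
+// independent cursors advance in the same order the requests were collected.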
+func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request] {
+	var keyIdx, prevKeyIdx int
+	return func(_ int, req *Request, err error) {
+		requestCtx := req.requestCtx
+		defer trace.StartRegion(requestCtx, "pdclient.regionReqDone").End()
+
+		if err != nil {
+			req.tryDone(err)
+			return
+		}
+
+		var id uint64
+		if req.key != nil {
+			id = resp.KeyIdMap[keyIdx]
+			keyIdx++
+		} else if req.prevKey != nil {
+			id = resp.PrevKeyIdMap[prevKeyIdx]
+			prevKeyIdx++
+		} else if req.id != 0 {
+			id = req.id
+		}
+		if region, ok := resp.RegionsById[id]; ok {
+			req.region = ConvertToRegion(region)
+		}
+		req.tryDone(err)
+	}
+}
+
+func (c *Cli) cancelCollectedRequests(err error) {
+	c.batchController.FinishCollectedRequests(requestFinisher(nil), err)
+}
+
+func (c *Cli) doneCollectedRequests(resp *pdpb.QueryRegionResponse) {
+	c.batchController.FinishCollectedRequests(requestFinisher(resp), nil)
+}
+
+// Close closes the router client.
+func (c *Cli) Close() {
+	if c == nil {
+		return
+	}
+	log.Info("[router] closing router client")
+
+	c.cancel()
+	c.wg.Wait()
+
+	log.Info("[router] router client is closed")
+}
+
+func (c *Cli) getLeaderURL() string {
+	url := c.leaderURL.Load()
+	if url == nil {
+		return ""
+	}
+	return url.(string)
+}
+
+func (c *Cli) updateLeaderURL(url string) error {
+	oldURL := c.getLeaderURL()
+	if oldURL == url {
+		return nil
+	}
+	c.leaderURL.Store(url)
+	c.scheduleUpdateConnection()
+
+	log.Info("[router] switch the router leader serving url",
+		zap.String("old-url", oldURL), zap.String("new-url", url))
+	return nil
+}
+
+// getLeaderClientConn returns the leader gRPC client connection.
+func (c *Cli) getLeaderClientConn() (*grpc.ClientConn, string) {
+	url := c.getLeaderURL()
+	if len(url) == 0 {
+		c.svcDiscovery.ScheduleCheckMemberChanged()
+		return nil, ""
+	}
+	cc, ok := c.svcDiscovery.GetClientConns().Load(url)
+	if !ok {
+		return nil, url
+	}
+	return cc.(*grpc.ClientConn), url
+}
+
+// scheduleUpdateConnection is used to schedule an update to the connection(s).
+func (c *Cli) scheduleUpdateConnection() {
+	select {
+	case c.updateConnectionCh <- struct{}{}:
+	default:
+	}
+}
+
+// connectionDaemon updates the router leader/primary/backup connection(s) in the background.
+// It aims to provide seamless connection updates so that the router client can keep serving
+// requests without interruption.
+func (c *Cli) connectionDaemon() {
+	defer c.wg.Done()
+	updaterCtx, updaterCancel := context.WithCancel(c.ctx)
+	defer updaterCancel()
+	updateTicker := time.NewTicker(sd.MemberUpdateInterval)
+	defer updateTicker.Stop()
+
+	log.Info("[router] connection daemon is started")
+	for {
+		c.updateConnection(updaterCtx)
+		select {
+		case <-updaterCtx.Done():
+			log.Info("[router] connection daemon is exiting")
+			return
+		case <-updateTicker.C:
+		case <-c.updateConnectionCh:
+		}
+	}
+}
+
+// updateConnection fetches the leader client connection and creates a new stream connection context if one does not already exist.
+func (c *Cli) updateConnection(ctx context.Context) {
+	cc, url := c.getLeaderClientConn()
+	if cc == nil || len(url) == 0 {
+		log.Warn("[router] got an invalid leader client connection", zap.String("url", url))
+		return
+	}
+	if c.conCtxMgr.Exist(url) {
+		log.Debug("[router] the router leader remains unchanged", zap.String("url", url))
+		return
+	}
+	stream, err := pdpb.NewPDClient(cc).QueryRegion(ctx)
+	if err != nil {
+		log.Error("[router] failed to create the router stream connection", errs.ZapError(err))
+	}
+	c.conCtxMgr.Store(ctx, url, stream)
+	// TODO: support the forwarding mechanism for the router client.
+	// TODO: support sending the router requests to the follower nodes.
+}
+
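+// dispatcher runs the batching loop: it fetches the pending router requests,
+// waits for a usable stream connection, and sends each batch as a single
+// QueryRegion exchange, retiring the stream whenever an error occurs.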
+func (c *Cli) dispatcher() {
+	defer c.wg.Done()
+
+	var (
+		stream            pdpb.PD_QueryRegionClient
+		streamURL         string
+		streamCtx         context.Context
+		timeoutTimer      *time.Timer
+		resetTimeoutTimer = func() {
+			if timeoutTimer == nil {
+				timeoutTimer = time.NewTimer(c.option.Timeout)
+			} else {
+				timeoutTimer.Reset(c.option.Timeout)
+			}
+		}
+		ctx, cancel = context.WithCancel(c.ctx)
+	)
+
+	log.Info("[router] dispatcher is started")
+	defer func() {
+		log.Info("[router] dispatcher is exiting")
+		cancel()
+		if timeoutTimer != nil {
+			timeoutTimer.Stop()
+		}
+		log.Info("[router] dispatcher exited")
+	}()
+batchLoop:
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
+		// Step 1: Fetch the pending router requests in batch.
+		err := c.batchController.FetchPendingRequests(ctx, c.requestCh, nil, 0)
+		if err != nil {
+			if err == context.Canceled {
+				log.Info("[router] stop fetching the pending router requests due to context canceled")
+			} else {
+				log.Error("[router] failed to fetch the pending router requests", errs.ZapError(err))
+			}
+			return
+		}
+
+		// Step 2: Choose a stream connection to send the router request.
+		resetTimeoutTimer()
+	connectionCtxChoosingLoop:
+		for {
+			// Check if the dispatcher is canceled or the timeout timer is triggered.
+			select {
+			case <-ctx.Done():
+				return
+			case <-timeoutTimer.C:
+				log.Error("[router] router stream connection is not ready until timeout, abort the batch")
+				c.svcDiscovery.ScheduleCheckMemberChanged()
+				c.batchController.FinishCollectedRequests(requestFinisher(nil), err)
+				continue batchLoop
+			default:
+			}
+			// Choose a stream connection to send the router request later.
+			connectionCtx := c.conCtxMgr.GetConnectionCtx()
+			if connectionCtx == nil {
+				log.Info("[router] router stream connection is not ready")
+				c.updateConnection(ctx)
+				continue connectionCtxChoosingLoop
+			}
+			streamCtx, streamURL, stream = connectionCtx.Ctx, connectionCtx.StreamURL, connectionCtx.Stream
+			// Check if the stream connection is canceled.
+			select {
+			case <-streamCtx.Done():
+				log.Info("[router] router stream connection is canceled", zap.String("stream-url", streamURL))
+				c.conCtxMgr.Release(streamURL)
+				continue connectionCtxChoosingLoop
+			default:
+			}
+			// The stream connection is ready, break the loop.
+			break connectionCtxChoosingLoop
+		}
+
+		// Step 3: Dispatch the router requests to the stream connection.
+		// TODO: timeout handling if the stream takes too long to process the requests.
+		err = c.processRequests(stream)
+		if err != nil {
+			if !c.handleProcessRequestError(ctx, streamURL, err) {
+				return
+			}
+		}
+	}
+}
+
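+// processRequests merges all collected requests into one QueryRegionRequest,
+// sends it over the stream and blocks until the batched response arrives.
+// Keys, PrevKeys and Ids preserve the collection order, and NeedBuckets is
+// set as soon as any request in the batch asks for bucket information.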
+func (c *Cli) processRequests(stream pdpb.PD_QueryRegionClient) error {
+	var (
+		requests     = c.batchController.GetCollectedRequests()
+		traceRegions = make([]*trace.Region, 0, len(requests))
+		spans        = make([]opentracing.Span, 0, len(requests))
+	)
+	for _, req := range requests {
+		traceRegions = append(traceRegions, trace.StartRegion(req.requestCtx, "pdclient.regionReqSend"))
+		if span := opentracing.SpanFromContext(req.requestCtx); span != nil && span.Tracer() != nil {
+			spans = append(spans, span.Tracer().StartSpan("pdclient.processRegionRequests", opentracing.ChildOf(span.Context())))
+		}
+	}
+	defer func() {
+		for i := range spans {
+			spans[i].Finish()
+		}
+		for i := range traceRegions {
+			traceRegions[i].End()
+		}
+	}()
+
+	queryReq := &pdpb.QueryRegionRequest{
+		Header: &pdpb.RequestHeader{
+			ClusterId: c.svcDiscovery.GetClusterID(),
+		},
+		Keys:     make([][]byte, 0, len(requests)),
+		PrevKeys: make([][]byte, 0, len(requests)),
+		Ids:      make([]uint64, 0, len(requests)),
+	}
+	for _, req := range requests {
+		if !queryReq.NeedBuckets && req.needBuckets {
+			queryReq.NeedBuckets = true
+		}
+		if req.key != nil {
+			queryReq.Keys = append(queryReq.Keys, req.key)
+		} else if req.prevKey != nil {
+			queryReq.PrevKeys = append(queryReq.PrevKeys, req.prevKey)
+		} else if req.id != 0 {
+			queryReq.Ids = append(queryReq.Ids, req.id)
+		} else {
+			panic("invalid region query request received")
+		}
+	}
+	err := stream.Send(queryReq)
+	if err != nil {
+		return err
+	}
+	resp, err := stream.Recv()
+	if err != nil {
+		return err
+	}
+	c.doneCollectedRequests(resp)
+	return nil
+}
+
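+// handleProcessRequestError fails the collected requests with the given
+// error, releases the broken stream connection, and triggers a member change
+// check. It returns false only if the dispatcher context is done, in which
+// case the dispatcher should exit.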
+func (c *Cli) handleProcessRequestError(
+	ctx context.Context,
+	streamURL string,
+	err error,
+) bool {
+	log.Error("[router] failed to process the router requests",
+		zap.String("stream-url", streamURL),
+		errs.ZapError(err))
+	c.cancelCollectedRequests(err)
+
+	select {
+	case <-ctx.Done():
+		return false
+	default:
+	}
+
+	// Delete the stream connection context.
+	c.conCtxMgr.Release(streamURL)
+	if errs.IsLeaderChange(err) {
+		// If the leader changes, call `CheckMemberChanged` synchronously to
+		// ensure the next round of router requests can be sent to the new leader.
+		if err := c.bo.Exec(ctx, c.svcDiscovery.CheckMemberChanged); err != nil {
+			select {
+			case <-ctx.Done():
+				return false
+			default:
+			}
+		}
+	} else {
+		// For other errors, we can just schedule a member change check asynchronously.
+		c.svcDiscovery.ScheduleCheckMemberChanged()
+	}
+
+	return true
+}
diff --git a/client/clients/router/request.go b/client/clients/router/request.go
new file mode 100644
index 00000000000..4578514597d
--- /dev/null
+++ b/client/clients/router/request.go
@@ -0,0 +1,114 @@
+// Copyright 2024 TiKV Project Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package router
+
+import (
+	"context"
+	"runtime/trace"
+	"sync"
+
+	"github.com/pingcap/errors"
+
+	"github.com/tikv/pd/client/opt"
+)
+
+// Request is a region info request.
+type Request struct {
+	requestCtx context.Context
+	clientCtx  context.Context
+
+	// A non-nil key marks this as a `GetRegion` request.
+	key []byte
+	// A non-nil prevKey marks this as a `GetPrevRegion` request.
+	prevKey []byte
+	// A non-zero id marks this as a `GetRegionByID` request.
+	id uint64
+
+	// needBuckets indicates whether the request also needs to fetch the region buckets.
+	needBuckets bool
+
+	done chan error
+	// region will be set after the request is done.
+	region *Region
+
+	// Runtime fields.
+	pool *sync.Pool
+}
+
+func (req *Request) tryDone(err error) {
+	select {
+	case req.done <- err:
+	default:
+	}
+}
+
+func (req *Request) wait() (*Region, error) {
+	// TODO: introduce the metrics.
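+	// The request is recycled into the pool only on the done path: if either
+	// context is canceled first, the dispatcher may still hold a reference to
+	// the request, so it must not be returned to the pool here.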
+	select {
+	case err := <-req.done:
+		defer req.pool.Put(req)
+		defer trace.StartRegion(req.requestCtx, "pdclient.regionReqDone").End()
+		if err != nil {
+			return nil, errors.WithStack(err)
+		}
+		return req.region, nil
+	case <-req.requestCtx.Done():
+		return nil, errors.WithStack(req.requestCtx.Err())
+	case <-req.clientCtx.Done():
+		return nil, errors.WithStack(req.clientCtx.Err())
+	}
+}
+
+// GetRegion implements the Client interface.
+func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) {
+	req := c.newRequest(ctx)
+	req.key = key
+	options := &opt.GetRegionOp{}
+	for _, opt := range opts {
+		opt(options)
+	}
+	req.needBuckets = options.NeedBuckets
+
+	c.requestCh <- req
+	return req.wait()
+}
+
+// GetPrevRegion implements the Client interface.
+func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) {
+	req := c.newRequest(ctx)
+	req.prevKey = key
+	options := &opt.GetRegionOp{}
+	for _, opt := range opts {
+		opt(options)
+	}
+	req.needBuckets = options.NeedBuckets
+
+	c.requestCh <- req
+	return req.wait()
+}
+
+// GetRegionByID implements the Client interface.
+func (c *Cli) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt.GetRegionOption) (*Region, error) {
+	req := c.newRequest(ctx)
+	req.id = regionID
+	options := &opt.GetRegionOp{}
+	for _, opt := range opts {
+		opt(options)
+	}
+	req.needBuckets = options.NeedBuckets
+
+	c.requestCh <- req
+	return req.wait()
+}
diff --git a/client/go.mod b/client/go.mod
index 78aef084ff7..a84bf303be1 100644
--- a/client/go.mod
+++ b/client/go.mod
@@ -10,7 +10,7 @@ require (
 	github.com/opentracing/opentracing-go v1.2.0
 	github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c
 	github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86
-	github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037
+	github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1
 	github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3
 	github.com/prometheus/client_golang v1.20.5
 	github.com/stretchr/testify v1.9.0
diff --git a/client/go.sum b/client/go.sum
index 4cca5ba3ad5..2873e4f550c 100644
--- a/client/go.sum
+++ b/client/go.sum
@@ -49,8 +49,8 @@ github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTm
 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg=
 github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 h1:tdMsjOqUR7YXHoBitzdebTvOjs/swniBTOLy5XiMtuE=
 github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86/go.mod h1:exzhVYca3WRtd6gclGNErRWb1qEgff3LYta0LvRmON4=
-github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 h1:xYNSJjYNur4Dr5bV+9BXK9n5E0T1zlcAN25XX68+mOg=
-github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8=
+github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1 h1:rTAyiswGyWSGHJVa4Mkhdi8YfGqfA4LrUVKsH9nrJ8E=
+github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8=
 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8IDP+SZrdhV1Kibl9KrHxJ9eciw=
 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4=
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
diff --git a/client/inner_client.go b/client/inner_client.go
index 7ce8f42386e..181ee2c9d52 100644
--- a/client/inner_client.go
+++ b/client/inner_client.go
@@ -27,6 +27,7 @@ import (
 	"github.com/pingcap/kvproto/pkg/pdpb"
 	"github.com/pingcap/log"
 
+	"github.com/tikv/pd/client/clients/router"
 	"github.com/tikv/pd/client/clients/tso"
 	"github.com/tikv/pd/client/errs"
 	"github.com/tikv/pd/client/metrics"
@@ -73,6 +74,15 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error {
 	return nil
 }
 
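+// initRouterClient creates the router client at most once; the inner client
+// lock guarantees that concurrent callers share a single instance.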
+func (c *innerClient) initRouterClient() {
+	c.Lock()
+	defer c.Unlock()
+	if c.routerClient != nil {
+		return
+	}
+	c.routerClient = router.NewClient(c.ctx, c.serviceDiscovery, c.option)
+}
+
 func (c *innerClient) setServiceMode(newMode pdpb.ServiceMode) {
 	c.Lock()
 	defer c.Unlock()
diff --git a/client/pkg/connectionctx/manager.go b/client/pkg/connectionctx/manager.go
index 04c1eb13d3a..fede8baf723 100644
--- a/client/pkg/connectionctx/manager.go
+++ b/client/pkg/connectionctx/manager.go
@@ -16,9 +16,8 @@ package connectionctx
 
 import (
 	"context"
+	"math/rand"
 	"sync"
-
-	"golang.org/x/exp/rand"
 )
 
 type connectionCtx[T any] struct {
diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go
index af8cdc00a7e..0f972cb2b8f 100644
--- a/tests/integrations/client/client_test.go
+++ b/tests/integrations/client/client_test.go
@@ -20,6 +20,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"math"
+	"math/rand"
 	"os"
 	"path"
 	"reflect"
@@ -1105,6 +1106,10 @@ func bootstrapServer(re *require.Assertions, header *pdpb.RequestHeader, client
 	re.Equal(pdpb.ErrorType_OK, resp.GetHeader().GetError().GetType())
 }
 
+func (suite *clientTestSuite) SetupTest() {
+	suite.grpcSvr.DirectlyGetRaftCluster().ResetRegionCache()
+}
+
 func (suite *clientTestSuite) TestGetRegion() {
 	re := suite.Require()
 	regionID := regionIDAllocator.alloc()
@@ -1204,7 +1209,6 @@ func (suite *clientTestSuite) TestGetPrevRegion() {
 		err := suite.regionHeartbeat.Send(req)
 		re.NoError(err)
 	}
-	time.Sleep(500 * time.Millisecond)
 	for i := range 20 {
 		testutil.Eventually(re, func() bool {
 			r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)})
@@ -1338,6 +1342,83 @@ func (suite *clientTestSuite) TestGetRegionByID() {
 	})
 }
 
+func (suite *clientTestSuite) TestGetRegionConcurrently() {
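+	// EnableRouterClient is not part of the public Client interface, so the
+	// test reaches it through an anonymous interface assertion.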
+	suite.client.(interface{ EnableRouterClient() }).EnableRouterClient()
+
+	re := suite.Require()
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+
+	regions := make([]*metapb.Region, 0, 2)
+	for i := range 2 {
+		regionID := regionIDAllocator.alloc()
+		region := &metapb.Region{
+			Id: regionID,
+			RegionEpoch: &metapb.RegionEpoch{
+				ConfVer: 1,
+				Version: 1,
+			},
+			StartKey: []byte{byte(i)},
+			EndKey:   []byte{byte(i + 1)},
+			Peers:    peers,
+		}
+		re.NoError(suite.regionHeartbeat.Send(&pdpb.RegionHeartbeatRequest{
+			Header: newHeader(),
+			Region: region,
+			Leader: peers[0],
+		}))
+		regions = append(regions, region)
+	}
+
+	const concurrency = 1000
+
+	wg := sync.WaitGroup{}
+	wg.Add(concurrency)
+	for range concurrency {
+		go func() {
+			defer wg.Done()
+			switch rand.Intn(3) {
+			case 0:
+				region := regions[0]
+				testutil.Eventually(re, func() bool {
+					r, err := suite.client.GetRegion(ctx, region.GetStartKey())
+					re.NoError(err)
+					if r == nil {
+						return false
+					}
+					return reflect.DeepEqual(region, r.Meta) &&
+						reflect.DeepEqual(peers[0], r.Leader) &&
+						r.Buckets == nil
+				})
+			case 1:
+				testutil.Eventually(re, func() bool {
+					r, err := suite.client.GetPrevRegion(ctx, regions[1].GetStartKey())
+					re.NoError(err)
+					if r == nil {
+						return false
+					}
+					return reflect.DeepEqual(regions[0], r.Meta) &&
+						reflect.DeepEqual(peers[0], r.Leader) &&
+						r.Buckets == nil
+				})
+			case 2:
+				region := regions[0]
+				testutil.Eventually(re, func() bool {
+					r, err := suite.client.GetRegionByID(ctx, region.GetId())
+					re.NoError(err)
+					if r == nil {
+						return false
+					}
+					return reflect.DeepEqual(region, r.Meta) &&
+						reflect.DeepEqual(peers[0], r.Leader) &&
+						r.Buckets == nil
+				})
+			}
+		}()
+	}
+	wg.Wait()
+}
+
 func (suite *clientTestSuite) TestGetStore() {
 	re := suite.Require()
 	cluster := suite.srv.GetRaftCluster()