@@ -52,13 +52,17 @@ struct Endpoint {
 
   int inComingCount = 0;
   int kStartDepth = 128;
-  int kRxDepth = 256;
+  int kRxDepth = 128;
   int kReplyDepth = kRxDepth;
   int kMaxInlineSize;
   WRContext *rx_ctx;
   WRContext *start_ctx;
   WRContext *reply_ctx;
 
+  struct ibv_mr *rx_ctx_mr = nullptr;
+  struct ibv_mr *start_ctx_mr = nullptr;
+  struct ibv_mr *reply_ctx_mr = nullptr;
+
   ThreadsafeQueue<WRContext *> free_start_ctx;
   ThreadsafeQueue<WRContext *> free_reply_ctx;
 
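The `ref_mr` field that later hunks read (`ctx->ref_mr->lkey`) is declared outside this diff. A minimal sketch of the `WRContext` shape the patch presumably assumes (field order and exact types are guesses, not confirmed by the diff):

```cpp
// Sketch only: buffer used to be a per-context struct ibv_mr *; after this
// patch it is a raw pointer to one chunk of a shared slab, and ref_mr points
// at the single MR registered over the whole slab (it supplies the lkey for
// SGEs but is owned by the Endpoint, not by the context).
struct WRContext {
  WRContextType type;
  char *buffer;            // chunk inside the shared registered slab
  struct ibv_mr *ref_mr;   // shared registration; not owned by this context
  void *private_data;
};
```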
@@ -94,27 +98,24 @@ struct Endpoint {
   }
 
   ~Endpoint() {
-    for (int i = 0; i < kRxDepth; ++i) {
-      if (!(rx_ctx[i].buffer)) {
-        continue;
-      }
-      free(rx_ctx[i].buffer->addr);
-      PS_CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0);
+    if (rx_ctx_mr) {
+      void *buf = rx_ctx_mr->addr;
+      PS_CHECK_EQ(ibv_dereg_mr(rx_ctx_mr), 0);
+      free(buf);
     }
 
-    for (int i = 0; i < kStartDepth; ++i) {
-      if (start_ctx[i].buffer) {
-        free(start_ctx[i].buffer->addr);
-        PS_CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0);
-      }
+    if (start_ctx_mr) {
+      void *buf = start_ctx_mr->addr;
+      PS_CHECK_EQ(ibv_dereg_mr(start_ctx_mr), 0);
+      free(buf);
     }
 
-    for (int i = 0; i < kReplyDepth; ++i) {
-      if (reply_ctx[i].buffer) {
-        free(reply_ctx[i].buffer->addr);
-        PS_CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0);
-      }
+    if (reply_ctx_mr) {
+      void *buf = reply_ctx_mr->addr;
+      PS_CHECK_EQ(ibv_dereg_mr(reply_ctx_mr), 0);
+      free(buf);
     }
+
     FOR_QPS {
       rdma_destroy_qp(cm_ids[qpIndex]);
       PS_CHECK_EQ(rdma_destroy_id(cm_ids[qpIndex]), 0) << strerror(errno);
@@ -150,24 +151,24 @@ struct Endpoint {
 
   void SetNodeID(int id) { node_id = id; }
 
-  void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx,
-                             ThreadsafeQueue<WRContext *> *queue, size_t num,
-                             WRContextType type) {
-    for (size_t i = 0; i < num; ++i) {
-      void *buf;
-      aligned_malloc(reinterpret_cast<void **>(&buf), kMempoolChunkSize);
-      PS_CHECK(buf);
-      struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0);
-      PS_CHECK(mr)
-          << "ibv_reg_mr failed: " << strerror(errno)
-          << "\nYou can try to reduce BYTEPS_RDMA_START_DEPTH (current "
-          << kStartDepth << ") or BYTEPS_RDMA_RX_DEPTH (current " << kRxDepth
-          << ").";
+  void InitWRContextHelper(struct ibv_pd *pd, WRContext *ctx,
+                           size_t num, WRContextType type,
+                           struct ibv_mr **mr, unsigned int access,
+                           ThreadsafeQueue<WRContext *> *queue = nullptr) {
+    char *buf;
+    aligned_malloc(reinterpret_cast<void **>(&buf), kMempoolChunkSize * num);
+    PS_CHECK(buf);
+    *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize * num, access);
+    PS_CHECK(*mr) << "ibv_reg_mr failed: " << strerror(errno);
 
+    for (size_t i = 0; i < num; ++i) {
       ctx[i].type = type;
-      ctx[i].buffer = mr;
+      ctx[i].buffer = buf + i * kMempoolChunkSize;
+      ctx[i].ref_mr = *mr;
       ctx[i].private_data = this;
-      queue->Push(&ctx[i]);
+      if (queue) {
+        queue->Push(&ctx[i]);
+      }
     }
   }
 
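What the new helper buys: one `aligned_malloc` and one `ibv_reg_mr` per context type instead of `num` of each. A hedged sketch of the resulting layout, using only names from this diff (the assertions mirror the helper body; they are not code from the patch):

```cpp
// After InitWRContextHelper(pd, start_ctx, kStartDepth,
//                           kRendezvousStartContext, &start_ctx_mr,
//                           /*access=*/0, &free_start_ctx):
for (int i = 0; i < kStartDepth; ++i) {
  // The contexts tile one slab...
  assert(start_ctx[i].buffer ==
         static_cast<char *>(start_ctx_mr->addr) + i * kMempoolChunkSize);
  // ...and share one registration, hence one lkey for every SGE.
  assert(start_ctx[i].ref_mr == start_ctx_mr);
}
// Teardown shrinks to a single ibv_dereg_mr(start_ctx_mr) + free(), as in
// the new ~Endpoint() above.
```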
@@ -192,32 +193,20 @@ struct Endpoint {
         << ", qp=" << id->qp->qp_num << ", maxInline=" << kMaxInlineSize;
     if (inited == 0) {
       rdma_provider = provider;
-      InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth,
-                            kRendezvousStartContext);
-      InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth,
-                            kRendezvousReplyContext);
+      InitWRContextHelper(pd, start_ctx, kStartDepth, kRendezvousStartContext,
+                          &start_ctx_mr, 0, &free_start_ctx);
+      InitWRContextHelper(pd, reply_ctx, kReplyDepth, kRendezvousReplyContext,
+                          &reply_ctx_mr, 0, &free_reply_ctx);
+      InitWRContextHelper(pd, rx_ctx, kRxDepth, kReceiveContext, &rx_ctx_mr,
+                          IBV_ACCESS_LOCAL_WRITE);
     }
 
+    // Only one QP actually uses the ctx buffers; the other QPs carry
+    // imm-data-only receives. Reposting the same rx_ctx on every QP is
+    // safe: the buffers are shared, but each post adds another RQE.
     for (int i = 0; i < kRxDepth; ++i) {
-      if (inited == 0) {
-        void *buf;
-        aligned_malloc(reinterpret_cast<void **>(&buf), kMempoolChunkSize);
-        PS_CHECK(buf);
-        struct ibv_mr *mr =
-            ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE);
-        PS_CHECK(mr)
-            << "ibv_reg_mr failed: " << strerror(errno)
-            << "\nYou can try to reduce BYTEPS_RDMA_START_DEPTH (default 128)"
-            << " or BYTEPS_RDMA_RX_DEPTH (default 2048)";
-
-        rx_ctx[i].type = kReceiveContext;
-        rx_ctx[i].buffer = mr;
-        rx_ctx[i].private_data = this;
-      }
-    }
-    for (int i = 0; i < kRxDepth / QP_NUM; ++i) {
       if (inited < QP_NUM) {
-        PostRecv(&rx_ctx[i + inited * QP_NUM], id);
+        PostRecv(&rx_ctx[i], id);
       }
     }
     inited++;
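The receive-posting change is easier to see cumulatively: previously each of the QP_NUM connections took a disjoint kRxDepth / QP_NUM slice of rx_ctx; now every connection posts all kRxDepth contexts. A sketch of the net effect, borrowing the FOR_QPS/cm_ids names from the destructor (the real code posts per connection as each one reaches Init):

```cpp
// Net effect once all QP_NUM connections are initialized (sketch):
FOR_QPS {
  for (int i = 0; i < kRxDepth; ++i) {
    PostRecv(&rx_ctx[i], cm_ids[qpIndex]);  // same chunk, distinct RQE per QP
  }
}
// Safe because only one QP delivers payload into these buffers; per the
// comment above, the other QPs' completions carry imm data only, so their
// RQEs never DMA into the shared chunks.
```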
@@ -247,9 +236,9 @@ struct Endpoint {
     memset(&wr, 0, sizeof(wr));
 
     struct ibv_sge sge;
-    sge.addr = reinterpret_cast<uint64_t>(ctx->buffer->addr);
+    sge.addr = reinterpret_cast<uint64_t>(ctx->buffer);
     sge.length = kMempoolChunkSize;
-    sge.lkey = ctx->buffer->lkey;
+    sge.lkey = ctx->ref_mr->lkey;
 
     wr.wr_id = reinterpret_cast<uint64_t>(ctx);
     wr.next = nullptr;
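On the completion side (not shown in this diff) the `wr_id` set above round-trips the context pointer; with this patch a poller reads the payload at `ctx->buffer` directly instead of `ctx->buffer->addr`. A sketch under that assumption, with `cq` standing in for the endpoint's completion queue:

```cpp
// Assumed completion path (sketch, not part of this patch):
struct ibv_wc wc;
while (ibv_poll_cq(cq, 1, &wc) > 0) {
  auto *ctx = reinterpret_cast<WRContext *>(wc.wr_id);
  char *payload = ctx->buffer;  // was ctx->buffer->addr before this patch
  // ... dispatch on ctx->type (kReceiveContext, kRendezvous*Context) ...
}
```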
@@ -395,7 +384,7 @@ class RDMATransport : public Transport {
     endpoint_->free_start_ctx.WaitAndPop(&context);
 
     RendezvousStart *req =
-        reinterpret_cast<RendezvousStart *>(context->buffer->addr);
+        reinterpret_cast<RendezvousStart *>(context->buffer);
     req->meta_len = msg_buf->inline_len;
     req->origin_addr = reinterpret_cast<uint64_t>(msg_buf);
     req->data_num = msg_buf->data.size();
@@ -406,7 +395,7 @@ class RDMATransport : public Transport {
 
     struct ibv_sge sge;
     sge.addr = reinterpret_cast<uint64_t>(req);
-    sge.lkey = context->buffer->lkey;
+    sge.lkey = context->ref_mr->lkey;
     sge.length = sizeof(RendezvousStart);
 
     struct ibv_send_wr wr, *bad_wr = nullptr;
@@ -475,7 +464,7 @@ class RDMATransport : public Transport {
     WRContext *reply_ctx_ptr = nullptr;
     endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx_ptr);
     auto *resp =
-        reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer->addr);
+        reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer);
 
     // Populate reply with addresses and rkeys for both buffers
     resp->meta_addr = reinterpret_cast<uint64_t>(buf_ctx->meta_buffer);
@@ -500,7 +489,7 @@ class RDMATransport : public Transport {
     struct ibv_sge sge;
     sge.addr = reinterpret_cast<uint64_t>(resp);
     sge.length = sizeof(RendezvousReply);
-    sge.lkey = reply_ctx_ptr->buffer->lkey;
+    sge.lkey = reply_ctx_ptr->ref_mr->lkey;
     struct ibv_send_wr wr, *bad_wr = nullptr;
     memset(&wr, 0, sizeof(wr));
     wr.wr_id = reinterpret_cast<uint64_t>(reply_ctx_ptr);
@@ -528,7 +517,7 @@ class RDMATransport : public Transport {
     WRContext *reply_ctx_ptr = nullptr;
     endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx_ptr);
     RendezvousReply *resp =
-        reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer->addr);
+        reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer);
 
     // In GDR mode, client still uses single buffer logic,
     // so we populate the single addr/rkey
@@ -547,7 +536,7 @@ class RDMATransport : public Transport {
     struct ibv_sge sge;
     sge.addr = reinterpret_cast<uint64_t>(resp);
     sge.length = sizeof(RendezvousReply);
-    sge.lkey = reply_ctx_ptr->buffer->lkey;
+    sge.lkey = reply_ctx_ptr->ref_mr->lkey;
     struct ibv_send_wr wr, *bad_wr = nullptr;
     memset(&wr, 0, sizeof(wr));
     wr.wr_id = reinterpret_cast<uint64_t>(reply_ctx_ptr);