Skip to content

Commit ddc124b

Browse files
committed
Optimize mr usage in WRContext
Signed-off-by: Guangguan Wang <[email protected]>
1 parent e1f30a4 commit ddc124b

File tree

3 files changed

+55
-66
lines changed

3 files changed

+55
-66
lines changed

src/rdma_transport.h

Lines changed: 51 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,17 @@ struct Endpoint {
5252

5353
int inComingCount = 0;
5454
int kStartDepth = 128;
55-
int kRxDepth = 256;
55+
int kRxDepth = 128;
5656
int kReplyDepth = kRxDepth;
5757
int kMaxInlineSize;
5858
WRContext *rx_ctx;
5959
WRContext *start_ctx;
6060
WRContext *reply_ctx;
6161

62+
struct ibv_mr *rx_ctx_mr = nullptr;
63+
struct ibv_mr *start_ctx_mr = nullptr;
64+
struct ibv_mr *reply_ctx_mr = nullptr;
65+
6266
ThreadsafeQueue<WRContext *> free_start_ctx;
6367
ThreadsafeQueue<WRContext *> free_reply_ctx;
6468

@@ -94,27 +98,24 @@ struct Endpoint {
9498
}
9599

96100
~Endpoint() {
97-
for (int i = 0; i < kRxDepth; ++i) {
98-
if (!(rx_ctx[i].buffer)) {
99-
continue;
100-
}
101-
free(rx_ctx[i].buffer->addr);
102-
PS_CHECK_EQ(ibv_dereg_mr(rx_ctx[i].buffer), 0);
101+
if (rx_ctx_mr) {
102+
void *buf = rx_ctx_mr->addr;
103+
PS_CHECK_EQ(ibv_dereg_mr(rx_ctx_mr), 0);
104+
free(buf);
103105
}
104106

105-
for (int i = 0; i < kStartDepth; ++i) {
106-
if (start_ctx[i].buffer) {
107-
free(start_ctx[i].buffer->addr);
108-
PS_CHECK_EQ(ibv_dereg_mr(start_ctx[i].buffer), 0);
109-
}
107+
if (start_ctx_mr) {
108+
void *buf = start_ctx_mr->addr;
109+
PS_CHECK_EQ(ibv_dereg_mr(start_ctx_mr), 0);
110+
free(buf);
110111
}
111112

112-
for (int i = 0; i < kReplyDepth; ++i) {
113-
if (reply_ctx[i].buffer) {
114-
free(reply_ctx[i].buffer->addr);
115-
PS_CHECK_EQ(ibv_dereg_mr(reply_ctx[i].buffer), 0);
116-
}
113+
if (reply_ctx_mr) {
114+
void *buf = reply_ctx_mr->addr;
115+
PS_CHECK_EQ(ibv_dereg_mr(reply_ctx_mr), 0);
116+
free(buf);
117117
}
118+
118119
FOR_QPS {
119120
rdma_destroy_qp(cm_ids[qpIndex]);
120121
PS_CHECK_EQ(rdma_destroy_id(cm_ids[qpIndex]), 0) << strerror(errno);
@@ -150,24 +151,24 @@ struct Endpoint {
150151

151152
void SetNodeID(int id) { node_id = id; }
152153

153-
void InitSendContextHelper(struct ibv_pd *pd, WRContext *ctx,
154-
ThreadsafeQueue<WRContext *> *queue, size_t num,
155-
WRContextType type) {
156-
for (size_t i = 0; i < num; ++i) {
157-
void *buf;
158-
aligned_malloc(reinterpret_cast<void **>(&buf), kMempoolChunkSize);
159-
PS_CHECK(buf);
160-
struct ibv_mr *mr = ibv_reg_mr(pd, buf, kMempoolChunkSize, 0);
161-
PS_CHECK(mr)
162-
<< "ibv_reg_mr failed: " << strerror(errno)
163-
<< "\nYou can try to reduce BYTEPS_RDMA_START_DEPTH (current "
164-
<< kStartDepth << ") or BYTEPS_RDMA_RX_DEPTH (current " << kRxDepth
165-
<< ").";
154+
void InitWRContextHelper(struct ibv_pd *pd, WRContext *ctx,
155+
size_t num, WRContextType type,
156+
struct ibv_mr **mr, unsigned int access,
157+
ThreadsafeQueue<WRContext *> *queue = nullptr) {
158+
char *buf;
159+
aligned_malloc(reinterpret_cast<void **>(&buf), kMempoolChunkSize * num);
160+
PS_CHECK(buf);
161+
*mr = ibv_reg_mr(pd, buf, kMempoolChunkSize * num, access);
162+
PS_CHECK(*mr) << "ibv_reg_mr failed: " << strerror(errno);
166163

164+
for (size_t i = 0; i < num; ++i) {
167165
ctx[i].type = type;
168-
ctx[i].buffer = mr;
166+
ctx[i].buffer = buf + i * kMempoolChunkSize;
167+
ctx[i].ref_mr = *mr;
169168
ctx[i].private_data = this;
170-
queue->Push(&ctx[i]);
169+
if (queue) {
170+
queue->Push(&ctx[i]);
171+
}
171172
}
172173
}
173174

@@ -192,32 +193,20 @@ struct Endpoint {
192193
<< ", qp=" << id->qp->qp_num << ", maxInline=" << kMaxInlineSize;
193194
if (inited == 0) {
194195
rdma_provider = provider;
195-
InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth,
196-
kRendezvousStartContext);
197-
InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth,
198-
kRendezvousReplyContext);
196+
InitWRContextHelper(pd, start_ctx, kStartDepth, kRendezvousStartContext,
197+
&start_ctx_mr, 0, &free_start_ctx);
198+
InitWRContextHelper(pd, reply_ctx, kReplyDepth, kRendezvousReplyContext,
199+
&reply_ctx_mr, 0, &free_reply_ctx);
200+
InitWRContextHelper(pd, rx_ctx, kRxDepth, kReceiveContext, &rx_ctx_mr,
201+
IBV_ACCESS_LOCAL_WRITE);
199202
}
200203

204+
// Only one QP will use the ctx buffers; the other QPs are just for imm receive.
205+
// So it is OK for all QPs to repeatedly post recv the rx_ctx: the same buffers,
206+
// but more RQEs.
201207
for (int i = 0; i < kRxDepth; ++i) {
202-
if (inited == 0) {
203-
void *buf;
204-
aligned_malloc(reinterpret_cast<void **>(&buf), kMempoolChunkSize);
205-
PS_CHECK(buf);
206-
struct ibv_mr *mr =
207-
ibv_reg_mr(pd, buf, kMempoolChunkSize, IBV_ACCESS_LOCAL_WRITE);
208-
PS_CHECK(mr)
209-
<< "ibv_reg_mr failed: " << strerror(errno)
210-
<< "\nYou can try to reduce BYTEPS_RDMA_START_DEPTH (default 128)"
211-
<< " or BYTEPS_RDMA_RX_DEPTH (default 2048)";
212-
213-
rx_ctx[i].type = kReceiveContext;
214-
rx_ctx[i].buffer = mr;
215-
rx_ctx[i].private_data = this;
216-
}
217-
}
218-
for (int i = 0; i < kRxDepth / QP_NUM; ++i) {
219208
if (inited < QP_NUM) {
220-
PostRecv(&rx_ctx[i + inited * QP_NUM], id);
209+
PostRecv(&rx_ctx[i], id);
221210
}
222211
}
223212
inited++;
@@ -247,9 +236,9 @@ struct Endpoint {
247236
memset(&wr, 0, sizeof(wr));
248237

249238
struct ibv_sge sge;
250-
sge.addr = reinterpret_cast<uint64_t>(ctx->buffer->addr);
239+
sge.addr = reinterpret_cast<uint64_t>(ctx->buffer);
251240
sge.length = kMempoolChunkSize;
252-
sge.lkey = ctx->buffer->lkey;
241+
sge.lkey = ctx->ref_mr->lkey;
253242

254243
wr.wr_id = reinterpret_cast<uint64_t>(ctx);
255244
wr.next = nullptr;
@@ -395,7 +384,7 @@ class RDMATransport : public Transport {
395384
endpoint_->free_start_ctx.WaitAndPop(&context);
396385

397386
RendezvousStart *req =
398-
reinterpret_cast<RendezvousStart *>(context->buffer->addr);
387+
reinterpret_cast<RendezvousStart *>(context->buffer);
399388
req->meta_len = msg_buf->inline_len;
400389
req->origin_addr = reinterpret_cast<uint64_t>(msg_buf);
401390
req->data_num = msg_buf->data.size();
@@ -406,7 +395,7 @@ class RDMATransport : public Transport {
406395

407396
struct ibv_sge sge;
408397
sge.addr = reinterpret_cast<uint64_t>(req);
409-
sge.lkey = context->buffer->lkey;
398+
sge.lkey = context->ref_mr->lkey;
410399
sge.length = sizeof(RendezvousStart);
411400

412401
struct ibv_send_wr wr, *bad_wr = nullptr;
@@ -475,7 +464,7 @@ class RDMATransport : public Transport {
475464
WRContext *reply_ctx_ptr = nullptr;
476465
endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx_ptr);
477466
auto *resp =
478-
reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer->addr);
467+
reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer);
479468

480469
// Populate reply with addresses and rkeys for both buffers
481470
resp->meta_addr = reinterpret_cast<uint64_t>(buf_ctx->meta_buffer);
@@ -500,7 +489,7 @@ class RDMATransport : public Transport {
500489
struct ibv_sge sge;
501490
sge.addr = reinterpret_cast<uint64_t>(resp);
502491
sge.length = sizeof(RendezvousReply);
503-
sge.lkey = reply_ctx_ptr->buffer->lkey;
492+
sge.lkey = reply_ctx_ptr->ref_mr->lkey;
504493
struct ibv_send_wr wr, *bad_wr = nullptr;
505494
memset(&wr, 0, sizeof(wr));
506495
wr.wr_id = reinterpret_cast<uint64_t>(reply_ctx_ptr);
@@ -528,7 +517,7 @@ class RDMATransport : public Transport {
528517
WRContext *reply_ctx_ptr = nullptr;
529518
endpoint_->free_reply_ctx.WaitAndPop(&reply_ctx_ptr);
530519
RendezvousReply *resp =
531-
reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer->addr);
520+
reinterpret_cast<RendezvousReply *>(reply_ctx_ptr->buffer);
532521

533522
// In GDR mode, client still uses single buffer logic,
534523
// so we populate the single addr/rkey
@@ -547,7 +536,7 @@ class RDMATransport : public Transport {
547536
struct ibv_sge sge;
548537
sge.addr = reinterpret_cast<uint64_t>(resp);
549538
sge.length = sizeof(RendezvousReply);
550-
sge.lkey = reply_ctx_ptr->buffer->lkey;
539+
sge.lkey = reply_ctx_ptr->ref_mr->lkey;
551540
struct ibv_send_wr wr, *bad_wr = nullptr;
552541
memset(&wr, 0, sizeof(wr));
553542
wr.wr_id = reinterpret_cast<uint64_t>(reply_ctx_ptr);

src/rdma_utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ class BackendMemoryAllocator {
228228

229229
struct WRContext {
230230
WRContextType type;
231-
struct ibv_mr *buffer;
231+
void *buffer;
232+
struct ibv_mr *ref_mr;
232233
void *private_data;
233234
};
234235

src/rdma_van.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -829,16 +829,15 @@ class RDMAVan : public Van {
829829
case IBV_WC_RECV: {
830830
PS_CHECK(wc[i].wc_flags & IBV_WC_WITH_IMM);
831831
uint32_t imm = wc[i].imm_data;
832-
struct ibv_mr *mr = context->buffer;
833832

834833
if (imm == kRendezvousStart) {
835834
RendezvousStart *req =
836-
reinterpret_cast<RendezvousStart *>(mr->addr);
835+
reinterpret_cast<RendezvousStart *>(context->buffer);
837836
auto trans = PS_CHECK_NOTNULL(endpoint->GetTransport());
838837
trans->SendRendezvousReply(req, addr_pool_);
839838
} else if (imm == kRendezvousReply) {
840839
RendezvousReply *resp =
841-
reinterpret_cast<RendezvousReply *>(mr->addr);
840+
reinterpret_cast<RendezvousReply *>(context->buffer);
842841

843842
uint64_t origin_addr = resp->origin_addr;
844843
uint32_t idx = resp->idx;

0 commit comments

Comments
 (0)