Skip to content

Commit e1f30a4

Browse files
committed
Introduce RdmaProvider and ERDMA support
Signed-off-by: Guangguan Wang <[email protected]>
1 parent 03deae2 commit e1f30a4

File tree

4 files changed

+189
-45
lines changed

4 files changed

+189
-45
lines changed

src/rdma_provider.h

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright 2025 Alibaba Group. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
// =============================================================================
15+
16+
#ifndef RDMA_PROVIDER_H_
17+
#define RDMA_PROVIDER_H_
18+
19+
#ifdef DMLC_USE_RDMA
20+
21+
#include <rdma/rdma_cma.h>

#include <cstdint>
#include <cstring>
#include <string>

#include "dmlc/logging.h"
#include "./ibvwarp.h"
26+
27+
namespace ps {
28+
29+
// Abstract interface that hides vendor-specific RDMA device behaviour.
//
// Instances are obtained through GetProvider() and are passed around as raw,
// non-owning pointers. The constructor is protected, so only derived
// provider classes can be instantiated.
class RdmaProvider {
 public:
  // Polymorphic base: a virtual destructor keeps destruction through a
  // RdmaProvider* well-defined should a provider ever be deleted.
  virtual ~RdmaProvider() = default;

  // Returns the provider matching the device behind `context`.
  static RdmaProvider *GetProvider(struct ibv_context *context);

  // Inline-data size requested for a queue pair; Endpoint::Init assigns the
  // result to ibv_qp_init_attr::cap.max_inline_data.
  virtual int InlineSize() = 0;

  // Binds `qp` to LAG port `port_num`. Callers treat a return value of 1 as
  // success and log a failure for anything else.
  virtual int SetQPLag(struct ibv_qp *qp, int port_num) = 0;

  // Returns non-zero when the work completion `wc` may be skipped by the
  // completion-polling loop instead of being reported as a fatal error.
  virtual int ErrIgnore(struct ibv_wc *wc) = 0;

 protected:
  RdmaProvider() {}
};
40+
41+
class ErdmaProvider: public RdmaProvider {
42+
public:
43+
static inline RdmaProvider *Get() {
44+
static ErdmaProvider *inst_ptr = new ErdmaProvider();
45+
return inst_ptr;
46+
}
47+
48+
static inline const char *DevPrefix() { return "erdma"; }
49+
int InlineSize() override { return 96; }
50+
int SetQPLag(struct ibv_qp *qp, int port_num) override { return 1; }
51+
int ErrIgnore(struct ibv_wc *wc) override {
52+
if (wc->status == IBV_WC_SUCCESS ||
53+
(wc->status == IBV_WC_WR_FLUSH_ERR && !wc->vendor_err)) {
54+
return 1;
55+
}
56+
return 0;
57+
}
58+
59+
private:
60+
explicit ErdmaProvider() {}
61+
};
62+
63+
class Mlx5Provider: public RdmaProvider {
64+
public:
65+
66+
static inline RdmaProvider *Get() {
67+
static RdmaProvider *inst_ptr = new Mlx5Provider();
68+
return inst_ptr;
69+
}
70+
71+
static inline const char *DevPrefix() { return "mlx5"; }
72+
int InlineSize() override { return 512; }
73+
74+
int SetQPLag(struct ibv_qp *qp, int port_num) override {
75+
int ret = wrap_mlx5dv_modify_qp_lag_port(qp, port_num);
76+
if (ret == 1) {
77+
uint8_t set_port = 0xff, act_port = 0xff;
78+
wrap_mlx5dv_query_qp_lag_port(qp, &set_port, &act_port);
79+
PS_LOG(INFO) << "QP LAG Port: QP: " << qp->qp_num
80+
<< ", Modify Port: " << port_num
81+
<< ", Set to Port: " << static_cast<int>(set_port)
82+
<< ", Active Port: " << static_cast<int>(act_port);
83+
}
84+
return ret;
85+
}
86+
87+
int ErrIgnore(struct ibv_wc *wc) override {
88+
return wc->status == IBV_WC_SUCCESS;
89+
}
90+
91+
private:
92+
explicit Mlx5Provider() {
93+
if (wrap_ibv_symbols() != 1) {
94+
PS_LOG(WARNING) << "Load mlx5 symbols fails.";
95+
}
96+
}
97+
};
98+
99+
// Selects the provider singleton for the device behind `context`.
//
// The device name is matched by substring: a name containing
// ErdmaProvider::DevPrefix() yields the ERDMA provider; every other device
// gets the mlx5 provider, with a warning when the name does not contain
// Mlx5Provider::DevPrefix() either.
//
// Returns nullptr when `context` is null or the device name cannot be
// obtained, so callers should check the result before using it.
//
// Declared inline because this definition lives in a header: without it,
// every translation unit including the header would emit its own external
// definition and the link would fail with duplicate symbols.
inline RdmaProvider *RdmaProvider::GetProvider(struct ibv_context *context) {
  if (context == nullptr) {
    PS_LOG(WARNING) << "cannot select rdma provider: ibv_context is null.";
    return nullptr;
  }

  const char *dev_name = ibv_get_device_name(context->device);
  if (dev_name == nullptr) {
    // strstr() on a null pointer is undefined behaviour; bail out instead.
    PS_LOG(WARNING) << "cannot select rdma provider: device name unavailable.";
    return nullptr;
  }

  if (strstr(dev_name, ErdmaProvider::DevPrefix()) != nullptr) {
    return ErdmaProvider::Get();
  }

  if (strstr(dev_name, Mlx5Provider::DevPrefix()) == nullptr) {
    PS_LOG(WARNING) << "rdma device(" << dev_name
                    << ") with unknow provider, use mlx5 as default, maybe "
                       "not compatible.";
  }
  return Mlx5Provider::Get();
}
112+
113+
} // namespace ps
114+
115+
#endif // DMLC_USE_RDMA
116+
#endif // RDMA_PROVIDER_H_

src/rdma_transport.h

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#include <unordered_map>
2828
#include <vector>
2929

30-
#include "./ibvwarp.h"
30+
#include "./rdma_provider.h"
3131
#include "./rdma_utils.h"
3232
#include "dmlc/logging.h"
3333
#include "ps/internal/multi_qp.h"
@@ -54,19 +54,19 @@ struct Endpoint {
5454
int kStartDepth = 128;
5555
int kRxDepth = 256;
5656
int kReplyDepth = kRxDepth;
57+
int kMaxInlineSize;
5758
WRContext *rx_ctx;
5859
WRContext *start_ctx;
5960
WRContext *reply_ctx;
6061

6162
ThreadsafeQueue<WRContext *> free_start_ctx;
6263
ThreadsafeQueue<WRContext *> free_reply_ctx;
6364

65+
RdmaProvider *rdma_provider = nullptr;
66+
6467
uint8_t inited = 0;
6568

6669
Endpoint() : node_id(Node::kEmpty), rx_ctx() {
67-
if (wrap_ibv_symbols() != 1) {
68-
PS_LOG(WARNING) << "Load mlx5 symbols fails.";
69-
}
7070
FOR_QPS {
7171
cm_ids[qpIndex] = nullptr;
7272
status_list[qpIndex] = IDLE;
@@ -171,7 +171,7 @@ struct Endpoint {
171171
}
172172
}
173173

174-
void Init(struct ibv_cq *cq, struct ibv_pd *pd, rdma_cm_id *id = nullptr) {
174+
void Init(struct ibv_cq *cq, struct ibv_pd *pd, RdmaProvider *provider, rdma_cm_id *id = nullptr) {
175175
struct ibv_qp_init_attr attr;
176176
memset(&attr, 0, sizeof(ibv_qp_init_attr));
177177
attr.send_cq = cq;
@@ -180,16 +180,18 @@ struct Endpoint {
180180
attr.cap.max_recv_wr = kRxDepth;
181181
attr.cap.max_send_sge = kSGEntry;
182182
attr.cap.max_recv_sge = kSGEntry;
183-
attr.cap.max_inline_data = 256;
183+
attr.cap.max_inline_data = provider->InlineSize();
184184
attr.qp_type = IBV_QPT_RC;
185185
attr.sq_sig_all = 0;
186186
PS_CHECK_EQ(rdma_create_qp(id, pd, &attr), 0)
187187
<< "Create RDMA queue pair failed: " << strerror(errno);
188188
id->pd = pd;
189+
kMaxInlineSize = attr.cap.max_inline_data;
189190

190191
PS_LOG(TRACE) << "qp created: pd=" << pd << " , cq=" << cq
191-
<< ", qp=" << id->qp->qp_num;
192+
<< ", qp=" << id->qp->qp_num << ", maxInline=" << kMaxInlineSize;
192193
if (inited == 0) {
194+
rdma_provider = provider;
193195
InitSendContextHelper(pd, start_ctx, &free_start_ctx, kStartDepth,
194196
kRendezvousStartContext);
195197
InitSendContextHelper(pd, reply_ctx, &free_reply_ctx, kReplyDepth,
@@ -227,21 +229,14 @@ struct Endpoint {
227229

228230
if (val == 1) {
229231
multi_qp_ = true;
232+
PS_CHECK(rdma_provider);
230233
FOR_QPS {
231234
int lag = 1 + qpIndex % 2;
232-
int ret = wrap_mlx5dv_modify_qp_lag_port(cm_ids[qpIndex]->qp, lag);
235+
int ret = rdma_provider->SetQPLag(cm_ids[qpIndex]->qp, lag);
233236
if (ret != 1) {
234-
PS_LOG(INFO) << "Failed to mlx5dv_modify_qp_lag_port qp ["
237+
PS_LOG(INFO) << "Failed to SetQPLag qp ["
235238
<< cm_ids[qpIndex]->qp->qp_num << "] to port: " << lag
236239
<< ", qp type: " << cm_ids[qpIndex]->qp->qp_type;
237-
} else {
238-
uint8_t set_port = 0xff, act_port = 0xff;
239-
wrap_mlx5dv_query_qp_lag_port(cm_ids[qpIndex]->qp, &set_port,
240-
&act_port);
241-
PS_LOG(INFO) << "QP LAG Port: QP: " << cm_ids[qpIndex]->qp->qp_num
242-
<< ", Modify Port: " << lag
243-
<< ", Set to Port: " << static_cast<int>(set_port)
244-
<< ", Active Port: " << static_cast<int>(act_port);
245240
}
246241
}
247242
}
@@ -310,6 +305,9 @@ class RDMATransport : public Transport {
310305
allocator_ = PS_CHECK_NOTNULL(allocator);
311306
pagesize_ = sysconf(_SC_PAGESIZE);
312307

308+
PS_CHECK_GT(endpoint_->kMaxInlineSize, 0);
309+
max_inline_size_ = endpoint_->kMaxInlineSize;
310+
313311
postoffice_ = postoffice;
314312
is_server_ = postoffice_->is_server();
315313
#ifdef STEPMESH_USE_GDR
@@ -342,32 +340,51 @@ class RDMATransport : public Transport {
342340
uint32_t rkey, uint32_t idx,
343341
bool inline_write = false,
344342
struct ibv_send_wr *prev_wr = nullptr) {
345-
struct ibv_sge sge;
346-
sge.addr = reinterpret_cast<uint64_t>(msg_buf->inline_buf);
347-
sge.length = msg_buf->inline_len;
348-
sge.lkey = allocator_->LocalKey(msg_buf->inline_buf);
349-
350-
struct ibv_send_wr wr = {}, *bad_wr = nullptr;
351-
memset(&wr, 0, sizeof(wr));
352-
wr.wr_id = reinterpret_cast<uint64_t>(msg_buf);
353-
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
354-
wr.next = nullptr;
355-
wr.imm_data = idx;
356-
wr.send_flags = IBV_SEND_SIGNALED;
357-
wr.sg_list = &sge;
358-
wr.num_sge = 1;
359-
wr.wr.rdma.remote_addr = remote_addr;
360-
wr.wr.rdma.rkey = rkey;
343+
struct ibv_send_wr wr[kRdmaMaxWRs], *bad_wr = nullptr;
344+
struct ibv_sge sge[kRdmaMaxWRs];
345+
uint32_t lkey = allocator_->LocalKey(msg_buf->inline_buf);
346+
size_t last_len = 0, offset = 0, len = 0;
347+
int num_wr = 1, last_flags = 0;
361348

362349
if (inline_write) {
363-
wr.send_flags |= IBV_SEND_INLINE;
350+
num_wr = DivUp(msg_buf->inline_len, max_inline_size_);
351+
last_len = msg_buf->inline_len % max_inline_size_;
352+
last_flags = IBV_SEND_INLINE | IBV_SEND_SIGNALED;
353+
} else {
354+
num_wr = 1;
355+
last_len = msg_buf->inline_len;
356+
last_flags = IBV_SEND_SIGNALED;
357+
}
358+
359+
PS_CHECK_LE(num_wr, kRdmaMaxWRs)
360+
<< "too many wrs, send_len: " << msg_buf->inline_len
361+
<< " , max_inline: " << max_inline_size_;
362+
memset(wr, 0, sizeof(struct ibv_send_wr) * num_wr);
363+
364+
for (int i = 0; i < num_wr; ++i) {
365+
bool is_last = (i == (num_wr - 1));
366+
len = is_last ? last_len : max_inline_size_;
367+
sge[i].addr = reinterpret_cast<uint64_t>(msg_buf->inline_buf + offset);
368+
sge[i].length = len;
369+
sge[i].lkey = lkey;
370+
371+
wr[i].wr_id = reinterpret_cast<uint64_t>(msg_buf);
372+
wr[i].opcode = is_last ? IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE;
373+
wr[i].next = is_last ? nullptr : &wr[i + 1];
374+
wr[i].imm_data = is_last ? idx : 0;
375+
wr[i].send_flags = is_last ? last_flags : 0;
376+
wr[i].sg_list = &sge[i];
377+
wr[i].num_sge = 1;
378+
wr[i].wr.rdma.remote_addr = remote_addr + offset;
379+
wr[i].wr.rdma.rkey = rkey;
380+
offset += len;
364381
}
365382

366383
if (prev_wr == nullptr) {
367-
PS_CHECK_EQ(ibv_post_send(endpoint_->cm_ids[0]->qp, &wr, &bad_wr), 0)
384+
PS_CHECK_EQ(ibv_post_send(endpoint_->cm_ids[0]->qp, wr, &bad_wr), 0)
368385
<< "ibv_post_send failed.";
369386
} else {
370-
prev_wr->next = &wr;
387+
prev_wr->next = &wr[0];
371388
PS_CHECK_EQ(ibv_post_send(endpoint_->cm_ids[0]->qp, prev_wr, &bad_wr), 0)
372389
<< "ibv_post_send failed.";
373390
}
@@ -839,6 +856,7 @@ class RDMATransport : public Transport {
839856

840857
protected:
841858
size_t pagesize_ = 4096;
859+
size_t max_inline_size_ = 0;
842860
Endpoint *endpoint_;
843861
MemoryAllocator *allocator_;
844862
BackendMemoryAllocator *mem_allocator_;

src/rdma_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ static const int kSGEntry = 1;
6262
static const int kTimeoutms = 1000;
6363
static const int kRdmaListenBacklog = 128;
6464
static const int kMaxHostnameLength = 16;
65+
static const int kRdmaMaxWRs = 12;
6566

6667
// should have the same prefix with BytePS shared memory
6768
// for pcie reduce: BytePS_Pcie_{pcie_id}_ShM_{JOB_ID}_{BYTEPS_KEY}

src/rdma_van.h

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,9 @@ class RDMAVan : public Van {
732732
context_ = context;
733733
PS_CHECK(context_) << "ibv_context* empty";
734734

735+
provider_ = RdmaProvider::GetProvider(context_);
736+
PS_CHECK(provider_) << "failed to get rdma provider";
737+
735738
pd_ = ibv_alloc_pd(context_);
736739
PS_CHECK(pd_) << "Failed to allocate protection domain";
737740

@@ -774,13 +777,18 @@ class RDMAVan : public Van {
774777

775778
PS_CHECK_GE(ne, 0);
776779
for (int i = 0; i < ne; ++i) {
777-
PS_CHECK(wc[i].status == IBV_WC_SUCCESS)
778-
<< "Failed status \n"
779-
<< ibv_wc_status_str(wc[i].status) << " " << wc[i].status << " "
780-
<< static_cast<uint64_t>(wc[i].wr_id) << " " << wc[i].vendor_err
781-
<< " " << wc[i].opcode << " "
782-
<< (wc[i].opcode == IBV_WC_RECV ? "RECV" : "OTHER")
783-
<< " postoffice ptr: " << reinterpret_cast<void *>(postoffice_);
780+
if (wc[i].status != IBV_WC_SUCCESS) {
781+
if (provider_->ErrIgnore(&wc[i])) {
782+
continue;
783+
}
784+
PS_LOG(FATAL) << "Failed status \n"
785+
<< ibv_wc_status_str(wc[i].status) << " " << wc[i].status << " "
786+
<< static_cast<uint64_t>(wc[i].wr_id) << " " << wc[i].vendor_err
787+
<< " " << wc[i].opcode << " "
788+
<< (wc[i].opcode == IBV_WC_RECV ? "RECV" : "OTHER")
789+
<< " postoffice ptr: " << reinterpret_cast<void *>(postoffice_);
790+
}
791+
784792

785793
// IBV_WC_RDMA_WRITE use msg_buf as the wr_id
786794
// so there won't be context and endpoint for this op
@@ -991,7 +999,7 @@ class RDMAVan : public Van {
991999
InitContext(id->verbs);
9921000
}
9931001

994-
endpoint->Init(cq_, pd_, id);
1002+
endpoint->Init(cq_, pd_, provider_, id);
9951003

9961004
bool is_local_node =
9971005
disable_ipc_
@@ -1045,7 +1053,7 @@ class RDMAVan : public Van {
10451053
if (context_ == nullptr) {
10461054
InitContext(id->verbs);
10471055
}
1048-
endpoint->Init(cq_, pd_, id);
1056+
endpoint->Init(cq_, pd_, provider_, id);
10491057
endpoint->inComingCount++;
10501058
RequestContext ctx;
10511059
ctx.node = static_cast<uint32_t>(my_node_.id);
@@ -1122,6 +1130,7 @@ class RDMAVan : public Van {
11221130

11231131
struct rdma_event_channel *event_channel_ = nullptr;
11241132
struct ibv_context *context_ = nullptr;
1133+
RdmaProvider *provider_ = nullptr;
11251134

11261135
// ibverbs protection domain
11271136
struct ibv_pd *pd_ = nullptr;

0 commit comments

Comments
 (0)