2727#include < unordered_map>
2828#include < vector>
2929
30- #include " ./ibvwarp .h"
30+ #include " ./rdma_provider .h"
3131#include " ./rdma_utils.h"
3232#include " dmlc/logging.h"
3333#include " ps/internal/multi_qp.h"
@@ -54,19 +54,19 @@ struct Endpoint {
5454 int kStartDepth = 128 ;
5555 int kRxDepth = 256 ;
5656 int kReplyDepth = kRxDepth ;
57+ int kMaxInlineSize ;
5758 WRContext *rx_ctx;
5859 WRContext *start_ctx;
5960 WRContext *reply_ctx;
6061
6162 ThreadsafeQueue<WRContext *> free_start_ctx;
6263 ThreadsafeQueue<WRContext *> free_reply_ctx;
6364
65+ RdmaProvider *rdma_provider = nullptr ;
66+
6467 uint8_t inited = 0 ;
6568
6669 Endpoint () : node_id(Node::kEmpty ), rx_ctx() {
67- if (wrap_ibv_symbols () != 1 ) {
68- PS_LOG (WARNING) << " Load mlx5 symbols fails." ;
69- }
7070 FOR_QPS {
7171 cm_ids[qpIndex] = nullptr ;
7272 status_list[qpIndex] = IDLE;
@@ -171,7 +171,7 @@ struct Endpoint {
171171 }
172172 }
173173
174- void Init (struct ibv_cq *cq, struct ibv_pd *pd, rdma_cm_id *id = nullptr ) {
174+ void Init (struct ibv_cq *cq, struct ibv_pd *pd, RdmaProvider *provider, rdma_cm_id *id = nullptr ) {
175175 struct ibv_qp_init_attr attr;
176176 memset (&attr, 0 , sizeof (ibv_qp_init_attr));
177177 attr.send_cq = cq;
@@ -180,16 +180,18 @@ struct Endpoint {
180180 attr.cap .max_recv_wr = kRxDepth ;
181181 attr.cap .max_send_sge = kSGEntry ;
182182 attr.cap .max_recv_sge = kSGEntry ;
183- attr.cap .max_inline_data = 256 ;
183+ attr.cap .max_inline_data = provider-> InlineSize () ;
184184 attr.qp_type = IBV_QPT_RC;
185185 attr.sq_sig_all = 0 ;
186186 PS_CHECK_EQ (rdma_create_qp (id, pd, &attr), 0 )
187187 << " Create RDMA queue pair failed: " << strerror (errno);
188188 id->pd = pd;
189+ kMaxInlineSize = attr.cap .max_inline_data ;
189190
190191 PS_LOG (TRACE) << " qp created: pd=" << pd << " , cq=" << cq
191- << " , qp=" << id->qp ->qp_num ;
192+ << " , qp=" << id->qp ->qp_num << " , maxInline= " << kMaxInlineSize ;
192193 if (inited == 0 ) {
194+ rdma_provider = provider;
193195 InitSendContextHelper (pd, start_ctx, &free_start_ctx, kStartDepth ,
194196 kRendezvousStartContext );
195197 InitSendContextHelper (pd, reply_ctx, &free_reply_ctx, kReplyDepth ,
@@ -227,21 +229,14 @@ struct Endpoint {
227229
228230 if (val == 1 ) {
229231 multi_qp_ = true ;
232+ PS_CHECK (rdma_provider);
230233 FOR_QPS {
231234 int lag = 1 + qpIndex % 2 ;
232- int ret = wrap_mlx5dv_modify_qp_lag_port (cm_ids[qpIndex]->qp , lag);
235+ int ret = rdma_provider-> SetQPLag (cm_ids[qpIndex]->qp , lag);
233236 if (ret != 1 ) {
234- PS_LOG (INFO) << " Failed to mlx5dv_modify_qp_lag_port qp ["
237+ PS_LOG (INFO) << " Failed to SetQPLag qp ["
235238 << cm_ids[qpIndex]->qp ->qp_num << " ] to port: " << lag
236239 << " , qp type: " << cm_ids[qpIndex]->qp ->qp_type ;
237- } else {
238- uint8_t set_port = 0xff , act_port = 0xff ;
239- wrap_mlx5dv_query_qp_lag_port (cm_ids[qpIndex]->qp , &set_port,
240- &act_port);
241- PS_LOG (INFO) << " QP LAG Port: QP: " << cm_ids[qpIndex]->qp ->qp_num
242- << " , Modify Port: " << lag
243- << " , Set to Port: " << static_cast <int >(set_port)
244- << " , Active Port: " << static_cast <int >(act_port);
245240 }
246241 }
247242 }
@@ -310,6 +305,9 @@ class RDMATransport : public Transport {
310305 allocator_ = PS_CHECK_NOTNULL (allocator);
311306 pagesize_ = sysconf (_SC_PAGESIZE);
312307
308+ PS_CHECK_GT (endpoint_->kMaxInlineSize , 0 );
309+ max_inline_size_ = endpoint_->kMaxInlineSize ;
310+
313311 postoffice_ = postoffice;
314312 is_server_ = postoffice_->is_server ();
315313#ifdef STEPMESH_USE_GDR
@@ -342,32 +340,51 @@ class RDMATransport : public Transport {
342340 uint32_t rkey, uint32_t idx,
343341 bool inline_write = false ,
344342 struct ibv_send_wr *prev_wr = nullptr ) {
345- struct ibv_sge sge;
346- sge.addr = reinterpret_cast <uint64_t >(msg_buf->inline_buf );
347- sge.length = msg_buf->inline_len ;
348- sge.lkey = allocator_->LocalKey (msg_buf->inline_buf );
349-
350- struct ibv_send_wr wr = {}, *bad_wr = nullptr ;
351- memset (&wr, 0 , sizeof (wr));
352- wr.wr_id = reinterpret_cast <uint64_t >(msg_buf);
353- wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
354- wr.next = nullptr ;
355- wr.imm_data = idx;
356- wr.send_flags = IBV_SEND_SIGNALED;
357- wr.sg_list = &sge;
358- wr.num_sge = 1 ;
359- wr.wr .rdma .remote_addr = remote_addr;
360- wr.wr .rdma .rkey = rkey;
343+ struct ibv_send_wr wr[kRdmaMaxWRs ], *bad_wr = nullptr ;
344+ struct ibv_sge sge[kRdmaMaxWRs ];
345+ uint32_t lkey = allocator_->LocalKey (msg_buf->inline_buf );
346+ size_t last_len = 0 , offset = 0 , len = 0 ;
347+ int num_wr = 1 , last_flags = 0 ;
361348
362349 if (inline_write) {
363- wr.send_flags |= IBV_SEND_INLINE;
350+ num_wr = DivUp (msg_buf->inline_len , max_inline_size_);
351+ last_len = msg_buf->inline_len % max_inline_size_;
352+ last_flags = IBV_SEND_INLINE | IBV_SEND_SIGNALED;
353+ } else {
354+ num_wr = 1 ;
355+ last_len = msg_buf->inline_len ;
356+ last_flags = IBV_SEND_SIGNALED;
357+ }
358+
359+ PS_CHECK_LE (num_wr, kRdmaMaxWRs )
360+ << " too many wrs, send_len: " << msg_buf->inline_len
361+ << " , max_inline: " << max_inline_size_;
362+ memset (wr, 0 , sizeof (struct ibv_send_wr ) * num_wr);
363+
364+ for (int i = 0 ; i < num_wr; ++i) {
365+ bool is_last = (i == (num_wr - 1 ));
366+ len = is_last ? last_len : max_inline_size_;
367+ sge[i].addr = reinterpret_cast <uint64_t >(msg_buf->inline_buf + offset);
368+ sge[i].length = len;
369+ sge[i].lkey = lkey;
370+
371+ wr[i].wr_id = reinterpret_cast <uint64_t >(msg_buf);
372+ wr[i].opcode = is_last ? IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE;
373+ wr[i].next = is_last ? nullptr : &wr[i + 1 ];
374+ wr[i].imm_data = is_last ? idx : 0 ;
375+ wr[i].send_flags = is_last ? last_flags : 0 ;
376+ wr[i].sg_list = &sge[i];
377+ wr[i].num_sge = 1 ;
378+ wr[i].wr .rdma .remote_addr = remote_addr + offset;
379+ wr[i].wr .rdma .rkey = rkey;
380+ offset += len;
364381 }
365382
366383 if (prev_wr == nullptr ) {
367- PS_CHECK_EQ (ibv_post_send (endpoint_->cm_ids [0 ]->qp , & wr, &bad_wr), 0 )
384+ PS_CHECK_EQ (ibv_post_send (endpoint_->cm_ids [0 ]->qp , wr, &bad_wr), 0 )
368385 << " ibv_post_send failed." ;
369386 } else {
370- prev_wr->next = ≀
387+ prev_wr->next = &wr[ 0 ] ;
371388 PS_CHECK_EQ (ibv_post_send (endpoint_->cm_ids [0 ]->qp , prev_wr, &bad_wr), 0 )
372389 << " ibv_post_send failed." ;
373390 }
@@ -839,6 +856,7 @@ class RDMATransport : public Transport {
839856
840857 protected:
841858 size_t pagesize_ = 4096 ;
859+ size_t max_inline_size_ = 0 ;
842860 Endpoint *endpoint_;
843861 MemoryAllocator *allocator_;
844862 BackendMemoryAllocator *mem_allocator_;
0 commit comments