Skip to content

Commit 03deae2

Browse files
authored
Merge pull request #26 from niehao100/public
Add our in-door api SimpleNotify for rapidly sync between GPU stream and rdma write/poll cq op.
2 parents 4045dff + 68ebf73 commit 03deae2

File tree

4 files changed

+68
-6
lines changed

4 files changed

+68
-6
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,6 @@ test: $(TEST)
104104

105105
af:
106106
@mkdir -p cmake_build
107-
@cd cmake_build; cmake .. -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=/usr/bin/python3 -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j
107+
@cd cmake_build; cmake .. -DCMAKE_CUDA_COMPILER=$(CMAKE_CUDA_COMPILER) -DPython_EXECUTABLE=$(shell which python3) -DCUDA_TOOLKIT_ROOT_DIR=$(CUDA_TOOLKIT_ROOT_DIR); make -j
108108
@mkdir -p build
109109

fserver/csrc/private.hpp

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,72 @@
1313

1414
using namespace ps;
1515
#ifdef DMLC_USE_CUDA
16-
void pybind_private(py::module &m){}
16+
class SimpleNotify{
17+
private:
18+
int notify_cnt = 1;
19+
CUdeviceptr dflag;
20+
uint32_t* hflag;
21+
std::thread th_;
22+
std::future<std::vector<ServerDataBatch>> fut;
23+
public:
24+
void init() {
25+
cudaHostAlloc(&hflag, sizeof(uint32_t), cudaHostAllocMapped);
26+
cudaHostGetDevicePointer((void**)&dflag, (void*)hflag, 0);
27+
}
28+
29+
// for worker
30+
void wait_event_done(){
31+
if (th_.joinable()) {
32+
th_.join();
33+
}
34+
}
35+
36+
// for worker
37+
void stream_wait_event(int handler) {
38+
auto stream = at::cuda::getCurrentCUDAStream();
39+
cuStreamWaitValue32((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
40+
th_ = std::thread([handler, this]{
41+
fworker_->Wait(handler);
42+
*(this->hflag) = this->notify_cnt;
43+
++(this->notify_cnt);
44+
});
45+
}
46+
47+
void block_now_stream() {
48+
auto stream = at::cuda::getCurrentCUDAStream();
49+
cuStreamWaitValue32((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
50+
}
51+
52+
// for server
53+
void block_now_stream_and_get_batch() {
54+
auto stream = at::cuda::getCurrentCUDAStream();
55+
cuStreamWaitValue32((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
56+
fut = std::async(std::launch::async, [this]{
57+
auto ret = get_batch();
58+
*(this->hflag) = this->notify_cnt;
59+
++(this->notify_cnt);
60+
return ret;
61+
});
62+
}
63+
64+
// for server
65+
std::vector<ServerDataBatch> get_future_batch_data(){
66+
return fut.get();
67+
}
68+
};
69+
70+
void pybind_private(py::module &m){
71+
py::class_<SimpleNotify>(m, "SimpleNotify")
72+
.def(py::init<>())
73+
.def("init", &SimpleNotify::init)
74+
.def("block_now_stream_and_get_batch", &SimpleNotify::block_now_stream_and_get_batch)
75+
.def("get_future_batch_data", &SimpleNotify::get_future_batch_data)
76+
.def("block_now_stream", &SimpleNotify::block_now_stream)
77+
.def("wait_event_done", &SimpleNotify::wait_event_done)
78+
.def("stream_wait_event", &SimpleNotify::stream_wait_event);
79+
}
1780
#else
1881
void pybind_private(py::module &m){}
1982
#endif //DMLC_USE_CUDA
2083

21-
#endif //PRIVATE_OPS_
84+
#endif //PRIVATE_OPS_

fserver/csrc/public.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ int instance_id_ = 0;
3838
int num_worker_ = 1;
3939
uint64_t worker_mask_ = 0x1;
4040

41-
typedef std::tuple<uint64_t, std::vector<torch::Tensor>, std::vector<uint64_t>>
42-
ServerDataBatch;
43-
4441
std::mutex mu_;
4542
uint64_t handler_counter_ = 0;
4643
std::unordered_map<uint64_t, AFTensorMeta> meta_map_;

fserver/csrc/util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@
2020

2121
#ifndef UTIL_H_
2222
#define UTIL_H_
23+
typedef std::tuple<uint64_t, std::vector<torch::Tensor>, std::vector<uint64_t>>
24+
ServerDataBatch;
2325
#endif // UTIL_H_

0 commit comments

Comments
 (0)