#include <atomic>
#include <stdexcept>

1414using namespace ps ;
1515#ifdef DMLC_USE_CUDA
16- void pybind_private (py::module &m){}
16+ class SimpleNotify {
17+ private:
18+ int notify_cnt = 1 ;
19+ CUdeviceptr dflag;
20+ uint32_t * hflag;
21+ std::thread th_;
22+ std::future<std::vector<ServerDataBatch>> fut;
23+ public:
24+ void init () {
25+ cudaHostAlloc (&hflag, sizeof (uint32_t ), cudaHostAllocMapped);
26+ cudaHostGetDevicePointer ((void **)&dflag, (void *)hflag, 0 );
27+ }
28+
29+ // for worker
30+ void wait_event_done (){
31+ if (th_.joinable ()) {
32+ th_.join ();
33+ }
34+ }
35+
36+ // for worker
37+ void stream_wait_event (int handler) {
38+ auto stream = at::cuda::getCurrentCUDAStream ();
39+ cuStreamWaitValue32 ((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
40+ th_ = std::thread ([handler, this ]{
41+ fworker_->Wait (handler);
42+ *(this ->hflag ) = this ->notify_cnt ;
43+ ++(this ->notify_cnt );
44+ });
45+ }
46+
47+ void block_now_stream () {
48+ auto stream = at::cuda::getCurrentCUDAStream ();
49+ cuStreamWaitValue32 ((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
50+ }
51+
52+ // for server
53+ void block_now_stream_and_get_batch () {
54+ auto stream = at::cuda::getCurrentCUDAStream ();
55+ cuStreamWaitValue32 ((CUstream)stream, dflag, notify_cnt, CU_STREAM_WAIT_VALUE_EQ);
56+ fut = std::async (std::launch::async, [this ]{
57+ auto ret = get_batch ();
58+ *(this ->hflag ) = this ->notify_cnt ;
59+ ++(this ->notify_cnt );
60+ return ret;
61+ });
62+ }
63+
64+ // for server
65+ std::vector<ServerDataBatch> get_future_batch_data (){
66+ return fut.get ();
67+ }
68+ };
69+
70+ void pybind_private (py::module &m){
71+ py::class_<SimpleNotify>(m, " SimpleNotify" )
72+ .def (py::init<>())
73+ .def (" init" , &SimpleNotify::init)
74+ .def (" block_now_stream_and_get_batch" , &SimpleNotify::block_now_stream_and_get_batch)
75+ .def (" get_future_batch_data" , &SimpleNotify::get_future_batch_data)
76+ .def (" block_now_stream" , &SimpleNotify::block_now_stream)
77+ .def (" wait_event_done" , &SimpleNotify::wait_event_done)
78+ .def (" stream_wait_event" , &SimpleNotify::stream_wait_event);
79+ }
1780#else
1881void pybind_private (py::module &m){}
1982#endif // DMLC_USE_CUDA
2083
21- #endif // PRIVATE_OPS_
84+ #endif // PRIVATE_OPS_
0 commit comments