Skip to content

Commit 93712a7

Browse files
author
zichao.zhang
committed
Avoid sleep-based block waiting in ShmQueue
1 parent 7c0fd85 commit 93712a7

File tree

2 files changed

+86
-44
lines changed

2 files changed

+86
-44
lines changed

graphlearn_torch/csrc/shm_queue.cc

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@ void ShmQueueMeta::Initialize(size_t max_block_num, size_t max_buf_size) {
6363
read_block_id_ = 0;
6464
alloc_offset_ = 0;
6565
released_offset_ = 0;
66-
sem_init(&alloc_lock_, 1, 1);
67-
sem_init(&release_lock_, 1, 1);
66+
pthread_mutexattr_t mutex_attr;
67+
pthread_mutexattr_init(&mutex_attr);
68+
pthread_mutexattr_setpshared(&mutex_attr, PTHREAD_PROCESS_SHARED);
69+
pthread_mutex_init(&mutex_, &mutex_attr);
70+
pthread_condattr_t cond_attr;
71+
pthread_condattr_init(&cond_attr);
72+
pthread_condattr_setpshared(&cond_attr, PTHREAD_PROCESS_SHARED);
73+
pthread_cond_init(&alloc_cond_, &cond_attr);
74+
pthread_cond_init(&release_cond_, &cond_attr);
6875
for (size_t i = 0; i < max_block_num_; i++) {
6976
GetBlockMeta(i).Initialize();
7077
}
@@ -74,16 +81,17 @@ void ShmQueueMeta::Finalize() {
7481
for (size_t i = 0; i < max_block_num_; i++) {
7582
GetBlockMeta(i).Finalize();
7683
}
77-
sem_destroy(&alloc_lock_);
78-
sem_destroy(&release_lock_);
84+
pthread_mutex_destroy(&mutex_);
85+
pthread_cond_destroy(&alloc_cond_);
86+
pthread_cond_destroy(&release_cond_);
7987
}
8088

8189
size_t ShmQueueMeta::GetBlockToWrite(size_t size,
8290
size_t* begin_offset,
8391
size_t* data_offset,
8492
size_t* end_offset) {
85-
sem_wait(&alloc_lock_);
86-
auto id = write_block_id_++;
93+
pthread_mutex_lock(&mutex_);
94+
size_t id = write_block_id_++;
8795
auto ring_offset = alloc_offset_ % max_buf_size_;
8896
auto tail_frag_size = max_buf_size_ - ring_offset;
8997
*begin_offset = alloc_offset_;
@@ -94,29 +102,78 @@ size_t ShmQueueMeta::GetBlockToWrite(size_t size,
94102
alloc_offset_ += size;
95103
*end_offset = alloc_offset_;
96104
Check(*end_offset - *begin_offset < max_buf_size_, "message is too large!");
97-
sem_post(&alloc_lock_);
105+
pthread_mutex_unlock(&mutex_);
106+
107+
// Notify one reader thread
108+
pthread_cond_signal(&alloc_cond_);
109+
110+
// Wait until no conflict
111+
pthread_mutex_lock(&mutex_);
112+
auto condition = [this, id, end_offset] {
113+
return (id < read_block_id_ + max_block_num_) &&
114+
(*end_offset < released_offset_ + max_buf_size_);
115+
};
116+
while (!condition()) {
117+
pthread_cond_wait(&release_cond_, &mutex_);
118+
}
119+
pthread_mutex_unlock(&mutex_);
120+
98121
return id;
99122
}
100123

101-
size_t ShmQueueMeta::GetBlockToRead() {
102-
return __sync_fetch_and_add(&read_block_id_, 1);
124+
size_t ShmQueueMeta::GetBlockToRead(uint32_t timeout_ms) {
125+
auto condition = [this] {
126+
if (read_block_id_ >= write_block_id_) {
127+
return false;
128+
}
129+
return true;
130+
};
131+
pthread_mutex_lock(&mutex_);
132+
if (timeout_ms == 0) {
133+
while (!condition()) {
134+
pthread_cond_wait(&alloc_cond_, &mutex_);
135+
}
136+
} else {
137+
struct timespec until {};
138+
clock_gettime(CLOCK_REALTIME, &until);
139+
until.tv_sec += timeout_ms / 1000;
140+
until.tv_nsec += (timeout_ms % 1000) * 1000000;
141+
while (!condition()) {
142+
int ret = pthread_cond_timedwait(&alloc_cond_, &mutex_, &until);
143+
if (ret == ETIMEDOUT) {
144+
throw QueueTimeoutError();
145+
}
146+
}
147+
}
148+
auto id = read_block_id_++;
149+
pthread_mutex_unlock(&mutex_);
150+
151+
// Notify all waiting writer thread
152+
pthread_cond_broadcast(&release_cond_);
153+
154+
return id;
103155
}
104156

105157
void ShmQueueMeta::ReleaseBlock(size_t id) {
106-
sem_wait(&release_lock_);
158+
pthread_mutex_lock(&mutex_);
107159
GetBlockMeta(id).release = true;
160+
bool release_some = false;
108161
while (id < read_block_id_) {
109162
auto& block = GetBlockMeta(id);
110163
if (block.release && block.begin == released_offset_) {
111164
released_offset_ = block.end;
112165
block.release = false;
113166
block.NotifyToWrite();
167+
release_some = true;
114168
} else {
115169
break;
116170
}
117171
id++;
118172
}
119-
sem_post(&release_lock_);
173+
pthread_mutex_unlock(&mutex_);
174+
if (release_some) {
175+
pthread_cond_broadcast(&release_cond_);
176+
}
120177
}
121178

122179
ShmQueueMeta::BlockMeta& ShmQueueMeta::GetBlockMeta(size_t id) {
@@ -192,13 +249,8 @@ void ShmQueue::Enqueue(const void* data, size_t size) {
192249

193250
void ShmQueue::Enqueue(size_t size, WriteFunc func) {
194251
size_t begin_offset, data_offset, end_offset;
195-
auto block_id = meta_->GetBlockToWrite(
196-
size, &begin_offset, &data_offset, &end_offset);
197-
// Check for ring buffer conflicts.
198-
while (block_id >= meta_->read_block_id_ + max_block_num_ ||
199-
end_offset >= meta_->released_offset_ + max_buf_size_) {
200-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
201-
}
252+
auto block_id =
253+
meta_->GetBlockToWrite(size, &begin_offset, &data_offset, &end_offset);
202254

203255
auto& block = meta_->GetBlockMeta(block_id);
204256
block.WaitForWriting();
@@ -213,19 +265,8 @@ void ShmQueue::Enqueue(size_t size, WriteFunc func) {
213265
block.NotifyToRead();
214266
}
215267

216-
ShmData ShmQueue::Dequeue(unsigned int timeout_ms) {
217-
auto timeout_duration = std::chrono::milliseconds(timeout_ms);
218-
auto start_time = std::chrono::steady_clock::now();
219-
while (meta_->read_block_id_ >= meta_->write_block_id_) {
220-
if (timeout_ms > 0) {
221-
auto elapsed_time = std::chrono::steady_clock::now() - start_time;
222-
if (elapsed_time > timeout_duration) {
223-
throw QueueTimeoutError();
224-
}
225-
}
226-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
227-
}
228-
auto block_id = meta_->GetBlockToRead();
268+
ShmData ShmQueue::Dequeue(uint32_t timeout_ms) {
269+
auto block_id = meta_->GetBlockToRead(timeout_ms);
229270

230271
auto& block = meta_->GetBlockMeta(block_id);
231272
block.WaitForReading();

graphlearn_torch/include/shm_queue.h

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class ShmQueueMeta {
139139

140140
/// Get a block to read.
141141
/// \return block id
142-
size_t GetBlockToRead();
142+
size_t GetBlockToRead(uint32_t timeout_ms = 0);
143143

144144
/// Release a block with block id.
145145
void ReleaseBlock(size_t id);
@@ -152,17 +152,18 @@ class ShmQueueMeta {
152152
void* GetData(size_t offset);
153153

154154
private:
155-
size_t max_block_num_;
156-
size_t max_buf_size_;
157-
size_t block_meta_offset_;
158-
size_t data_buf_offset_;
159-
size_t write_block_id_;
160-
size_t read_block_id_;
161-
size_t alloc_offset_;
162-
size_t released_offset_;
163-
sem_t alloc_lock_;
164-
sem_t release_lock_;
165-
friend class ShmQueue;
155+
size_t max_block_num_;
156+
size_t max_buf_size_;
157+
size_t block_meta_offset_;
158+
size_t data_buf_offset_;
159+
size_t write_block_id_;
160+
size_t read_block_id_;
161+
size_t alloc_offset_;
162+
size_t released_offset_;
163+
pthread_mutex_t mutex_;
164+
pthread_cond_t alloc_cond_;
165+
pthread_cond_t release_cond_;
166+
friend class ShmQueue;
166167
};
167168

168169
/// Shared-Memory Queue should be constructed and destructed on main process.
@@ -206,7 +207,7 @@ class ShmQueue {
206207

207208
/// Dequeue a message on child process.
208209
/// \return `ShmData`
209-
ShmData Dequeue(unsigned int timeout_ms = 0);
210+
ShmData Dequeue(uint32_t timeout_ms = 0);
210211

211212
bool Empty();
212213

0 commit comments

Comments
 (0)