Skip to content

Avoid sleep-based blocking wait in ShmQueue #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 72 additions & 31 deletions graphlearn_torch/csrc/shm_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,15 @@ void ShmQueueMeta::Initialize(size_t max_block_num, size_t max_buf_size) {
read_block_id_ = 0;
alloc_offset_ = 0;
released_offset_ = 0;
sem_init(&alloc_lock_, 1, 1);
sem_init(&release_lock_, 1, 1);
pthread_mutexattr_t mutex_attr;
pthread_mutexattr_init(&mutex_attr);
pthread_mutexattr_setpshared(&mutex_attr, PTHREAD_PROCESS_SHARED);
pthread_mutex_init(&mutex_, &mutex_attr);
pthread_condattr_t cond_attr;
pthread_condattr_init(&cond_attr);
pthread_condattr_setpshared(&cond_attr, PTHREAD_PROCESS_SHARED);
pthread_cond_init(&alloc_cond_, &cond_attr);
pthread_cond_init(&release_cond_, &cond_attr);
for (size_t i = 0; i < max_block_num_; i++) {
GetBlockMeta(i).Initialize();
}
Expand All @@ -74,16 +81,17 @@ void ShmQueueMeta::Finalize() {
for (size_t i = 0; i < max_block_num_; i++) {
GetBlockMeta(i).Finalize();
}
sem_destroy(&alloc_lock_);
sem_destroy(&release_lock_);
pthread_mutex_destroy(&mutex_);
pthread_cond_destroy(&alloc_cond_);
pthread_cond_destroy(&release_cond_);
}

size_t ShmQueueMeta::GetBlockToWrite(size_t size,
size_t* begin_offset,
size_t* data_offset,
size_t* end_offset) {
sem_wait(&alloc_lock_);
auto id = write_block_id_++;
pthread_mutex_lock(&mutex_);
size_t id = write_block_id_++;
auto ring_offset = alloc_offset_ % max_buf_size_;
auto tail_frag_size = max_buf_size_ - ring_offset;
*begin_offset = alloc_offset_;
Expand All @@ -94,29 +102,78 @@ size_t ShmQueueMeta::GetBlockToWrite(size_t size,
alloc_offset_ += size;
*end_offset = alloc_offset_;
Check(*end_offset - *begin_offset < max_buf_size_, "message is too large!");
sem_post(&alloc_lock_);
pthread_mutex_unlock(&mutex_);

// Notify one reader thread
pthread_cond_signal(&alloc_cond_);

// Wait until no conflict
pthread_mutex_lock(&mutex_);
auto condition = [this, id, end_offset] {
return (id < read_block_id_ + max_block_num_) &&
(*end_offset < released_offset_ + max_buf_size_);
};
while (!condition()) {
pthread_cond_wait(&release_cond_, &mutex_);
}
pthread_mutex_unlock(&mutex_);

return id;
}

size_t ShmQueueMeta::GetBlockToRead() {
return __sync_fetch_and_add(&read_block_id_, 1);
size_t ShmQueueMeta::GetBlockToRead(uint32_t timeout_ms) {
auto condition = [this] {
if (read_block_id_ >= write_block_id_) {
return false;
}
return true;
};
pthread_mutex_lock(&mutex_);
if (timeout_ms == 0) {
while (!condition()) {
pthread_cond_wait(&alloc_cond_, &mutex_);
}
} else {
struct timespec until {};
clock_gettime(CLOCK_REALTIME, &until);
until.tv_sec += timeout_ms / 1000;
until.tv_nsec += (timeout_ms % 1000) * 1000000;
while (!condition()) {
int ret = pthread_cond_timedwait(&alloc_cond_, &mutex_, &until);
if (ret == ETIMEDOUT) {
throw QueueTimeoutError();
}
}
}
auto id = read_block_id_++;
pthread_mutex_unlock(&mutex_);

// Notify all waiting writer thread
pthread_cond_broadcast(&release_cond_);

return id;
}

void ShmQueueMeta::ReleaseBlock(size_t id) {
sem_wait(&release_lock_);
pthread_mutex_lock(&mutex_);
GetBlockMeta(id).release = true;
bool release_some = false;
while (id < read_block_id_) {
auto& block = GetBlockMeta(id);
if (block.release && block.begin == released_offset_) {
released_offset_ = block.end;
block.release = false;
block.NotifyToWrite();
release_some = true;
} else {
break;
}
id++;
}
sem_post(&release_lock_);
pthread_mutex_unlock(&mutex_);
if (release_some) {
pthread_cond_broadcast(&release_cond_);
}
}

ShmQueueMeta::BlockMeta& ShmQueueMeta::GetBlockMeta(size_t id) {
Expand Down Expand Up @@ -192,13 +249,8 @@ void ShmQueue::Enqueue(const void* data, size_t size) {

void ShmQueue::Enqueue(size_t size, WriteFunc func) {
size_t begin_offset, data_offset, end_offset;
auto block_id = meta_->GetBlockToWrite(
size, &begin_offset, &data_offset, &end_offset);
// Check for ring buffer conflicts.
while (block_id >= meta_->read_block_id_ + max_block_num_ ||
end_offset >= meta_->released_offset_ + max_buf_size_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
auto block_id =
meta_->GetBlockToWrite(size, &begin_offset, &data_offset, &end_offset);

auto& block = meta_->GetBlockMeta(block_id);
block.WaitForWriting();
Expand All @@ -213,19 +265,8 @@ void ShmQueue::Enqueue(size_t size, WriteFunc func) {
block.NotifyToRead();
}

ShmData ShmQueue::Dequeue(unsigned int timeout_ms) {
auto timeout_duration = std::chrono::milliseconds(timeout_ms);
auto start_time = std::chrono::steady_clock::now();
while (meta_->read_block_id_ >= meta_->write_block_id_) {
if (timeout_ms > 0) {
auto elapsed_time = std::chrono::steady_clock::now() - start_time;
if (elapsed_time > timeout_duration) {
throw QueueTimeoutError();
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
auto block_id = meta_->GetBlockToRead();
ShmData ShmQueue::Dequeue(uint32_t timeout_ms) {
auto block_id = meta_->GetBlockToRead(timeout_ms);

auto& block = meta_->GetBlockMeta(block_id);
block.WaitForReading();
Expand Down
27 changes: 14 additions & 13 deletions graphlearn_torch/include/shm_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class ShmQueueMeta {

/// Get a block to read.
/// \return block id
size_t GetBlockToRead();
size_t GetBlockToRead(uint32_t timeout_ms = 0);

/// Release a block with block id.
void ReleaseBlock(size_t id);
Expand All @@ -152,17 +152,18 @@ class ShmQueueMeta {
void* GetData(size_t offset);

private:
size_t max_block_num_;
size_t max_buf_size_;
size_t block_meta_offset_;
size_t data_buf_offset_;
size_t write_block_id_;
size_t read_block_id_;
size_t alloc_offset_;
size_t released_offset_;
sem_t alloc_lock_;
sem_t release_lock_;
friend class ShmQueue;
size_t max_block_num_;
size_t max_buf_size_;
size_t block_meta_offset_;
size_t data_buf_offset_;
size_t write_block_id_;
size_t read_block_id_;
size_t alloc_offset_;
size_t released_offset_;
pthread_mutex_t mutex_;
pthread_cond_t alloc_cond_;
pthread_cond_t release_cond_;
friend class ShmQueue;
};

/// Shared-Memory Queue should be constructed and destructed on main process.
Expand Down Expand Up @@ -206,7 +207,7 @@ class ShmQueue {

/// Dequeue a message on child process.
/// \return `ShmData`
ShmData Dequeue(unsigned int timeout_ms = 0);
ShmData Dequeue(uint32_t timeout_ms = 0);

bool Empty();

Expand Down