Replace the semaphores and 1 ms polling loops of the shared-memory queue
with one process-shared mutex and two condition variables.

Writers now block inside ShmQueueMeta::GetBlockToWrite() until the block
slot and buffer range are free; readers block inside
ShmQueueMeta::GetBlockToRead(), optionally bounded by a timeout.

NOTE(review): this patch does not touch any #include lines; confirm that
<pthread.h>, <cerrno> and <ctime> are reachable from shm_queue.h/.cc.

diff --git a/graphlearn_torch/csrc/shm_queue.cc b/graphlearn_torch/csrc/shm_queue.cc
index 11302e7f..57f17b12 100644
--- a/graphlearn_torch/csrc/shm_queue.cc
+++ b/graphlearn_torch/csrc/shm_queue.cc
@@ -63,8 +63,19 @@ void ShmQueueMeta::Initialize(size_t max_block_num, size_t max_buf_size) {
   read_block_id_ = 0;
   alloc_offset_ = 0;
   released_offset_ = 0;
-  sem_init(&alloc_lock_, 1, 1);
-  sem_init(&release_lock_, 1, 1);
+  // Make the lock and both condition variables process-shared.
+  pthread_mutexattr_t mutex_attr;
+  pthread_mutexattr_init(&mutex_attr);
+  pthread_mutexattr_setpshared(&mutex_attr, PTHREAD_PROCESS_SHARED);
+  pthread_mutex_init(&mutex_, &mutex_attr);
+  // Attribute objects are not needed once the primitives are initialized.
+  pthread_mutexattr_destroy(&mutex_attr);
+  pthread_condattr_t cond_attr;
+  pthread_condattr_init(&cond_attr);
+  pthread_condattr_setpshared(&cond_attr, PTHREAD_PROCESS_SHARED);
+  pthread_cond_init(&alloc_cond_, &cond_attr);
+  pthread_cond_init(&release_cond_, &cond_attr);
+  pthread_condattr_destroy(&cond_attr);
   for (size_t i = 0; i < max_block_num_; i++) {
     GetBlockMeta(i).Initialize();
   }
@@ -74,16 +85,17 @@ void ShmQueueMeta::Finalize() {
   for (size_t i = 0; i < max_block_num_; i++) {
     GetBlockMeta(i).Finalize();
   }
-  sem_destroy(&alloc_lock_);
-  sem_destroy(&release_lock_);
+  pthread_mutex_destroy(&mutex_);
+  pthread_cond_destroy(&alloc_cond_);
+  pthread_cond_destroy(&release_cond_);
 }
 
 size_t ShmQueueMeta::GetBlockToWrite(size_t size,
                                      size_t* begin_offset,
                                      size_t* data_offset,
                                      size_t* end_offset) {
-  sem_wait(&alloc_lock_);
-  auto id = write_block_id_++;
+  pthread_mutex_lock(&mutex_);
+  size_t id = write_block_id_++;
   auto ring_offset = alloc_offset_ % max_buf_size_;
   auto tail_frag_size = max_buf_size_ - ring_offset;
   *begin_offset = alloc_offset_;
@@ -94,29 +106,82 @@ size_t ShmQueueMeta::GetBlockToWrite(size_t size,
   alloc_offset_ += size;
   *end_offset = alloc_offset_;
   Check(*end_offset - *begin_offset < max_buf_size_, "message is too large!");
-  sem_post(&alloc_lock_);
+  pthread_mutex_unlock(&mutex_);
+
+  // Wake one reader waiting for a block.
+  pthread_cond_signal(&alloc_cond_);
+
+  // Wait until neither the block slot nor the buffer range is still in use.
+  pthread_mutex_lock(&mutex_);
+  auto condition = [this, id, end_offset] {
+    return (id < read_block_id_ + max_block_num_) &&
+           (*end_offset < released_offset_ + max_buf_size_);
+  };
+  while (!condition()) {
+    pthread_cond_wait(&release_cond_, &mutex_);
+  }
+  pthread_mutex_unlock(&mutex_);
+
   return id;
 }
 
-size_t ShmQueueMeta::GetBlockToRead() {
-  return __sync_fetch_and_add(&read_block_id_, 1);
+size_t ShmQueueMeta::GetBlockToRead(uint32_t timeout_ms) {
+  // A block is readable once the writers are ahead of the readers.
+  auto condition = [this] { return read_block_id_ < write_block_id_; };
+  pthread_mutex_lock(&mutex_);
+  if (timeout_ms == 0) {
+    while (!condition()) {
+      pthread_cond_wait(&alloc_cond_, &mutex_);
+    }
+  } else {
+    struct timespec until {};
+    clock_gettime(CLOCK_REALTIME, &until);
+    until.tv_sec += timeout_ms / 1000;
+    until.tv_nsec += (timeout_ms % 1000) * 1000000;
+    // Carry overflow into tv_sec: pthread_cond_timedwait rejects a
+    // tv_nsec of one second or more with EINVAL instead of waiting.
+    if (until.tv_nsec >= 1000000000L) {
+      until.tv_sec += 1;
+      until.tv_nsec -= 1000000000L;
+    }
+    while (!condition()) {
+      int ret = pthread_cond_timedwait(&alloc_cond_, &mutex_, &until);
+      if (ret == ETIMEDOUT && !condition()) {
+        // Unlock before unwinding so the shared mutex is not left held.
+        pthread_mutex_unlock(&mutex_);
+        throw QueueTimeoutError();
+      }
+    }
+  }
+  auto id = read_block_id_++;
+  pthread_mutex_unlock(&mutex_);
+
+  // Notify all waiting writer threads
+  pthread_cond_broadcast(&release_cond_);
+
+  return id;
 }
 
 void ShmQueueMeta::ReleaseBlock(size_t id) {
-  sem_wait(&release_lock_);
+  pthread_mutex_lock(&mutex_);
   GetBlockMeta(id).release = true;
+  bool release_some = false;
   while (id < read_block_id_) {
     auto& block = GetBlockMeta(id);
     if (block.release && block.begin == released_offset_) {
       released_offset_ = block.end;
       block.release = false;
       block.NotifyToWrite();
+      release_some = true;
     } else {
       break;
     }
     id++;
   }
-  sem_post(&release_lock_);
+  pthread_mutex_unlock(&mutex_);
+  if (release_some) {
+    pthread_cond_broadcast(&release_cond_);
+  }
 }
 
 ShmQueueMeta::BlockMeta& ShmQueueMeta::GetBlockMeta(size_t id) {
@@ -192,13 +257,8 @@ void ShmQueue::Enqueue(const void* data, size_t size) {
 
 void ShmQueue::Enqueue(size_t size, WriteFunc func) {
   size_t begin_offset, data_offset, end_offset;
-  auto block_id = meta_->GetBlockToWrite(
-      size, &begin_offset, &data_offset, &end_offset);
-  // Check for ring buffer conflicts.
-  while (block_id >= meta_->read_block_id_ + max_block_num_ ||
-         end_offset >= meta_->released_offset_ + max_buf_size_) {
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-  }
+  auto block_id =
+      meta_->GetBlockToWrite(size, &begin_offset, &data_offset, &end_offset);
   auto& block = meta_->GetBlockMeta(block_id);
   block.WaitForWriting();
 
@@ -213,19 +273,8 @@ void ShmQueue::Enqueue(size_t size, WriteFunc func) {
   block.NotifyToRead();
 }
 
-ShmData ShmQueue::Dequeue(unsigned int timeout_ms) {
-  auto timeout_duration = std::chrono::milliseconds(timeout_ms);
-  auto start_time = std::chrono::steady_clock::now();
-  while (meta_->read_block_id_ >= meta_->write_block_id_) {
-    if (timeout_ms > 0) {
-      auto elapsed_time = std::chrono::steady_clock::now() - start_time;
-      if (elapsed_time > timeout_duration) {
-        throw QueueTimeoutError();
-      }
-    }
-    std::this_thread::sleep_for(std::chrono::milliseconds(1));
-  }
-  auto block_id = meta_->GetBlockToRead();
+ShmData ShmQueue::Dequeue(uint32_t timeout_ms) {
+  auto block_id = meta_->GetBlockToRead(timeout_ms);
   auto& block = meta_->GetBlockMeta(block_id);
   block.WaitForReading();
 
diff --git a/graphlearn_torch/include/shm_queue.h b/graphlearn_torch/include/shm_queue.h
index 5eefe971..40fa431f 100644
--- a/graphlearn_torch/include/shm_queue.h
+++ b/graphlearn_torch/include/shm_queue.h
@@ -139,7 +139,9 @@ class ShmQueueMeta {
 
   /// Get a block to read.
+  /// \param timeout_ms max wait in milliseconds, 0 waits indefinitely
+  /// \throw QueueTimeoutError if no block is readable before the timeout
   /// \return block id
-  size_t GetBlockToRead();
+  size_t GetBlockToRead(uint32_t timeout_ms = 0);
 
   /// Release a block with block id.
   void ReleaseBlock(size_t id);
@@ -160,8 +162,12 @@ class ShmQueueMeta {
   size_t read_block_id_;
   size_t alloc_offset_;
   size_t released_offset_;
-  sem_t alloc_lock_;
-  sem_t release_lock_;
+  /// Held while updating the block ids and offsets above.
+  pthread_mutex_t mutex_;
+  /// Signaled after a writer allocates a block.
+  pthread_cond_t alloc_cond_;
+  /// Broadcast after a block is taken for reading or released.
+  pthread_cond_t release_cond_;
   friend class ShmQueue;
 };
 
@@ -206,7 +212,8 @@ class ShmQueue {
 
   /// Dequeue a message on child process.
+  /// \param timeout_ms max wait for a readable block in ms, 0 means no limit
   /// \return `ShmData`
-  ShmData Dequeue(unsigned int timeout_ms = 0);
+  ShmData Dequeue(uint32_t timeout_ms = 0);
 
   bool Empty();
 