From 0ad4df848524138ed70eca407c4134a5cb8201b3 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Wed, 25 Sep 2024 12:30:30 -0700 Subject: [PATCH 1/3] implement gloo abort --- gloo/common/CMakeLists.txt | 1 + gloo/common/error.cc | 46 ++++++++++++++++++++++++++++ gloo/common/error.h | 6 ++++ gloo/transport/tcp/unbound_buffer.cc | 16 ++++++++-- gloo/transport/uv/unbound_buffer.cc | 26 ++++++++++++---- 5 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 gloo/common/error.cc diff --git a/gloo/common/CMakeLists.txt b/gloo/common/CMakeLists.txt index 307588a89..4b8e4c5d2 100644 --- a/gloo/common/CMakeLists.txt +++ b/gloo/common/CMakeLists.txt @@ -1,6 +1,7 @@ set(GLOO_COMMON_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/logging.cc" "${CMAKE_CURRENT_SOURCE_DIR}/utils.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/error.cc" ) set(GLOO_COMMON_HDRS diff --git a/gloo/common/error.cc b/gloo/common/error.cc new file mode 100644 index 000000000..c14f38b4a --- /dev/null +++ b/gloo/common/error.cc @@ -0,0 +1,46 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "gloo/common/error.h" + +namespace gloo { + + +std::list _cvs; +std::mutex _cvs_mutex; + +std::atomic_bool _is_aborted_flag(false); + +bool _is_aborted() { + return _is_aborted_flag.load(); +} + +void abort() { + _is_aborted_flag.store(true); + std::lock_guard guard(_cvs_mutex); + for(auto& cv : _cvs) { + if(cv != NULL) { + cv->notify_all(); + } + } + GLOO_THROW("GLOO ABORTED"); +} + +void _register_cv(std::condition_variable *cv) { + std::lock_guard guard(_cvs_mutex); + _cvs.push_back(cv); +} + +void _deregister_cv(std::condition_variable *cv) { + std::lock_guard guard(_cvs_mutex); + _cvs.remove(cv); +} +} // namespace gloo diff --git a/gloo/common/error.h b/gloo/common/error.h index 4eac45ec8..c7e98fa46 100644 --- a/gloo/common/error.h +++ b/gloo/common/error.h @@ -10,6 +10,7 @@ #include #include +#include #include "gloo/common/string.h" @@ -20,6 +21,11 @@ namespace gloo { const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero(); +bool _is_aborted(); +void abort(); +void _register_cv(std::condition_variable *cv); +void _deregister_cv(std::condition_variable *cv); + // A base class for all gloo runtime errors struct Exception : public std::runtime_error { Exception() = delete; diff --git a/gloo/transport/tcp/unbound_buffer.cc b/gloo/transport/tcp/unbound_buffer.cc index f6545db7a..f03639622 100644 --- a/gloo/transport/tcp/unbound_buffer.cc +++ b/gloo/transport/tcp/unbound_buffer.cc @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer( recvRank_(-1), sendCompletions_(0), sendRank_(-1), - shareableNonOwningPtr_(this) {} + shareableNonOwningPtr_(this) { + gloo::_register_cv(&recvCv_); + gloo::_register_cv(&sendCv_); +} -UnboundBuffer::~UnboundBuffer() {} +UnboundBuffer::~UnboundBuffer() { + gloo::_deregister_cv(&recvCv_); + gloo::_deregister_cv(&sendCv_); +} void UnboundBuffer::handleRecvCompletion(int rank) { std::lock_guard lock(m_); @@ -60,6 +66,9 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) { if (recvCompletions_ == 0) { auto done = recvCv_.wait_for(lock, timeout, [&] { throwIfException(); + if(gloo::_is_aborted()) { + abortWaitRecv_ = true; + } return abortWaitRecv_ || recvCompletions_ > 0; }); if (!done) { @@ -111,6 +120,9 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) { if (sendCompletions_ == 0) { auto done = sendCv_.wait_for(lock, timeout, [&] { throwIfException(); + if(gloo::_is_aborted()) { + abortWaitSend_ = true; + } return abortWaitSend_ || sendCompletions_ > 0; }); if (!done) { diff --git a/gloo/transport/uv/unbound_buffer.cc b/gloo/transport/uv/unbound_buffer.cc index bc9ba1c97..858d89417 100644 --- a/gloo/transport/uv/unbound_buffer.cc +++ b/gloo/transport/uv/unbound_buffer.cc @@ -28,9 +28,15 @@ UnboundBuffer::UnboundBuffer( recvRank_(-1), sendCompletions_(0), sendRank_(-1), - shareableNonOwningPtr_(this) {} + shareableNonOwningPtr_(this) { + gloo::_register_cv(&recvCv_); + gloo::_register_cv(&sendCv_); +} -UnboundBuffer::~UnboundBuffer() {} +UnboundBuffer::~UnboundBuffer() { + gloo::_deregister_cv(&recvCv_); + gloo::_deregister_cv(&sendCv_); +} void UnboundBuffer::handleRecvCompletion(int rank) { std::lock_guard lock(mutex_); @@ -58,8 +64,12 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) { } if (recvCompletions_ == 0) { - auto done = recvCv_.wait_for( - lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; }); + auto done = recvCv_.wait_for(lock, timeout, [&] { + if(gloo::_is_aborted()) { + abortWaitRecv_ = true; + } + return abortWaitRecv_ || recvCompletions_ > 0; + }); if (!done) { throw ::gloo::IoException(GLOO_ERROR_MSG( "Timed out waiting ", @@ -94,8 +104,12 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) { } if (sendCompletions_ == 0) { - auto done = sendCv_.wait_for( - lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; }); + auto done = sendCv_.wait_for(lock, timeout, [&] { + if(gloo::_is_aborted()) { + abortWaitSend_ = true; + } + return abortWaitSend_ || sendCompletions_ > 0; + }); if (!done) { throw ::gloo::IoException(GLOO_ERROR_MSG( "Timed out waiting ", From a7a49d18383af95051cf9d128be6ff0f79e4fe92 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Tue, 5 Nov 2024 17:25:44 +0400 Subject: [PATCH 2/3] add test --- gloo/test/CMakeLists.txt | 1 + gloo/test/abort_test.cc | 145 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 gloo/test/abort_test.cc diff --git a/gloo/test/CMakeLists.txt b/gloo/test/CMakeLists.txt index 743e089ee..a0e060457 100644 --- a/gloo/test/CMakeLists.txt +++ b/gloo/test/CMakeLists.txt @@ -1,6 +1,7 @@ find_package(OpenSSL 1.1 REQUIRED EXACT) set(GLOO_TEST_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/abort_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allgatherv_test.cc" "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc" diff --git a/gloo/test/abort_test.cc b/gloo/test/abort_test.cc new file mode 100644 index 000000000..aec1ae35f --- /dev/null +++ b/gloo/test/abort_test.cc @@ -0,0 +1,145 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include "gloo/barrier_all_to_all.h" +#include "gloo/barrier_all_to_one.h" +#include "gloo/broadcast.h" +#include "gloo/test/base_test.h" + +namespace gloo { +namespace test { +namespace { + +// Function to instantiate and run algorithm. +using Func = void(std::shared_ptr<::gloo::Context>); + +// Test parameterization. +using Param = std::tuple>; + +// Test fixture. +class BarrierTest : public BaseTest, + public ::testing::WithParamInterface {}; + +TEST_P(BarrierTest, SinglePointer) { + const auto transport = std::get<0>(GetParam()); + const auto contextSize = std::get<1>(GetParam()); + const auto fn = std::get<2>(GetParam()); + + spawn(transport, contextSize, [&](std::shared_ptr context) { + fn(context); + }); +} + +static std::function barrierAllToAll = + [](std::shared_ptr<::gloo::Context> context) { + ::gloo::BarrierAllToAll algorithm(context); + algorithm.run(); + }; + +INSTANTIATE_TEST_CASE_P( + BarrierAllToAll, + BarrierTest, + ::testing::Combine( + ::testing::ValuesIn(kTransportsForClassAlgorithms), + ::testing::Range(2, 16), + ::testing::Values(barrierAllToAll))); + +static std::function barrierAllToOne = + [](std::shared_ptr<::gloo::Context> context) { + ::gloo::BarrierAllToOne algorithm(context); + algorithm.run(); + }; + +INSTANTIATE_TEST_CASE_P( + BarrierAllToOne, + BarrierTest, + ::testing::Combine( + ::testing::ValuesIn(kTransportsForClassAlgorithms), + ::testing::Range(2, 16), + ::testing::Values(barrierAllToOne))); + +// Synchronized version of std::chrono::clock::now(). +// All processes participating in the specified context will +// see the same value. +template +std::chrono::time_point syncNow(std::shared_ptr context) { + const typename clock::time_point now = clock::now(); + typename clock::duration::rep count = now.time_since_epoch().count(); + BroadcastOptions opts(context); + opts.setRoot(0); + opts.setOutput(&count, 1); + broadcast(opts); + return typename clock::time_point(typename clock::duration(count)); +} + +using NewParam = std::tuple; + +class BarrierNewTest : public BaseTest, + public ::testing::WithParamInterface {}; + +TEST_P(BarrierNewTest, Default) { + const auto transport = std::get<0>(GetParam()); + const auto contextSize = std::get<1>(GetParam()); + + spawn(transport, contextSize, [&](std::shared_ptr context) { + BarrierOptions opts(context); + + // Run barrier to synchronize processes after starting. + barrier(opts); + + // Take turns in sleeping for a bit and checking that all processes + // saw that artificial delay through the barrier. + auto singleProcessDelay = std::chrono::milliseconds(1000); + for (size_t i = 0; i < context->size; i++) { + const auto start = syncNow(context); + if (i == context->rank) { + /* sleep override */ + std::this_thread::sleep_for(singleProcessDelay); + } + + barrier(opts); + abort(); + + // Expect all processes to have taken less than the sleep, as abort was called + auto stop = std::chrono::high_resolution_clock::now(); + auto delta = std::chrono::duration_cast( + stop - start); + ASSERT_LE(delta.count(), singleProcessDelay.count()); + } + }); +} + +INSTANTIATE_TEST_CASE_P( + BarrierNewDefault, + BarrierNewTest, + ::testing::Combine( + ::testing::ValuesIn(kTransportsForFunctionAlgorithms), + ::testing::Values(1, 2, 4, 7))); + +TEST_F(BarrierNewTest, TestTimeout) { + spawn(Transport::TCP, 2, [&](std::shared_ptr context) { + BarrierOptions opts(context); + opts.setTimeout(std::chrono::milliseconds(10)); + if (context->rank == 0) { + try { + barrier(opts); + FAIL() << "Expected exception to be thrown"; + } catch (::gloo::IoException& e) { + ASSERT_NE(std::string(e.what()).find("Timed out"), std::string::npos); + } + } + }); +} + +} // namespace +} // namespace test +} // namespace gloo From 563b8e1f115ac31413e765e43eb6a6bcd0ba152f Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Thu, 7 Nov 2024 14:08:49 +0400 Subject: [PATCH 3/3] update test --- gloo/test/abort_test.cc | 109 ++++++++-------------------------------- 1 file changed, 22 insertions(+), 87 deletions(-) diff --git a/gloo/test/abort_test.cc b/gloo/test/abort_test.cc index aec1ae35f..da8c4ac6e 100644 --- a/gloo/test/abort_test.cc +++ b/gloo/test/abort_test.cc @@ -19,54 +19,6 @@ namespace gloo { namespace test { namespace { -// Function to instantiate and run algorithm. -using Func = void(std::shared_ptr<::gloo::Context>); - -// Test parameterization. -using Param = std::tuple>; - -// Test fixture. -class BarrierTest : public BaseTest, - public ::testing::WithParamInterface {}; - -TEST_P(BarrierTest, SinglePointer) { - const auto transport = std::get<0>(GetParam()); - const auto contextSize = std::get<1>(GetParam()); - const auto fn = std::get<2>(GetParam()); - - spawn(transport, contextSize, [&](std::shared_ptr context) { - fn(context); - }); -} - -static std::function barrierAllToAll = - [](std::shared_ptr<::gloo::Context> context) { - ::gloo::BarrierAllToAll algorithm(context); - algorithm.run(); - }; - -INSTANTIATE_TEST_CASE_P( - BarrierAllToAll, - BarrierTest, - ::testing::Combine( - ::testing::ValuesIn(kTransportsForClassAlgorithms), - ::testing::Range(2, 16), - ::testing::Values(barrierAllToAll))); - -static std::function barrierAllToOne = - [](std::shared_ptr<::gloo::Context> context) { - ::gloo::BarrierAllToOne algorithm(context); - algorithm.run(); - }; - -INSTANTIATE_TEST_CASE_P( - BarrierAllToOne, - BarrierTest, - ::testing::Combine( - ::testing::ValuesIn(kTransportsForClassAlgorithms), - ::testing::Range(2, 16), - ::testing::Values(barrierAllToOne))); - // Synchronized version of std::chrono::clock::now(). // All processes participating in the specified context will // see the same value. @@ -83,10 +35,10 @@ std::chrono::time_point syncNow(std::shared_ptr context) { using NewParam = std::tuple; -class BarrierNewTest : public BaseTest, - public ::testing::WithParamInterface {}; +class AbortBarrierTest : public BaseTest, + public ::testing::WithParamInterface {}; -TEST_P(BarrierNewTest, Default) { +TEST_P(AbortBarrierTest, Default) { const auto transport = std::get<0>(GetParam()); const auto contextSize = std::get<1>(GetParam()); @@ -96,49 +48,32 @@ TEST_P(BarrierNewTest, Default) { // Run barrier to synchronize processes after starting. barrier(opts); - // Take turns in sleeping for a bit and checking that all processes - // saw that artificial delay through the barrier. - auto singleProcessDelay = std::chrono::milliseconds(1000); - for (size_t i = 0; i < context->size; i++) { - const auto start = syncNow(context); - if (i == context->rank) { - /* sleep override */ - std::this_thread::sleep_for(singleProcessDelay); - } - + auto timeout = std::chrono::milliseconds(context->getTimeout()); + const auto start = syncNow(context); + // Run barrier on all ranks but 0 so it hangs + if (context->rank != 0) { barrier(opts); - abort(); + } - // Expect all processes to have taken less than the sleep, as abort was called - auto stop = std::chrono::high_resolution_clock::now(); - auto delta = std::chrono::duration_cast( - stop - start); - ASSERT_LE(delta.count(), singleProcessDelay.count()); + // Abort should unhang the barrier + try { + abort(); + } catch (const Exception &e) { + EXPECT_TRUE(strstr(e.what(), "GLOO ABORTED") != NULL); } + + // Expect all processes to have taken less than the timeout, as abort was + // called + auto stop = std::chrono::high_resolution_clock::now(); + auto delta = std::chrono::duration_cast(stop - start); + ASSERT_LE(delta.count(), timeout.count() / 4); }); } INSTANTIATE_TEST_CASE_P( - BarrierNewDefault, - BarrierNewTest, - ::testing::Combine( - ::testing::ValuesIn(kTransportsForFunctionAlgorithms), - ::testing::Values(1, 2, 4, 7))); - -TEST_F(BarrierNewTest, TestTimeout) { - spawn(Transport::TCP, 2, [&](std::shared_ptr context) { - BarrierOptions opts(context); - opts.setTimeout(std::chrono::milliseconds(10)); - if (context->rank == 0) { - try { - barrier(opts); - FAIL() << "Expected exception to be thrown"; - } catch (::gloo::IoException& e) { - ASSERT_NE(std::string(e.what()).find("Timed out"), std::string::npos); - } - } - }); -} + AbortBarrier, AbortBarrierTest, + ::testing::Combine(::testing::ValuesIn(kTransportsForFunctionAlgorithms), + ::testing::Values(1, 2, 4, 7))); } // namespace } // namespace test