diff --git a/packages/react-native-audio-api/common/cpp/audioapi/HostObjects/utils/AudioDecoderHostObject.cpp b/packages/react-native-audio-api/common/cpp/audioapi/HostObjects/utils/AudioDecoderHostObject.cpp index 050d24a2c..f660f76d8 100644 --- a/packages/react-native-audio-api/common/cpp/audioapi/HostObjects/utils/AudioDecoderHostObject.cpp +++ b/packages/react-native-audio-api/common/cpp/audioapi/HostObjects/utils/AudioDecoderHostObject.cpp @@ -32,29 +32,30 @@ JSI_HOST_FUNCTION_IMPL(AudioDecoderHostObject, decodeWithMemoryBlock) { auto sampleRate = args[1].getNumber(); - auto promise = promiseVendor_->createPromise( - [data, size, sampleRate](std::shared_ptr promise) { - std::thread([data, size, sampleRate, promise = std::move(promise)]() { - auto result = - AudioDecoder::decodeWithMemoryBlock(data, size, sampleRate); - - if (!result) { - promise->reject("Failed to decode audio data."); - return; - } - - auto audioBufferHostObject = - std::make_shared(result); - - promise->resolve([audioBufferHostObject = std::move( - audioBufferHostObject)](jsi::Runtime &runtime) { - auto jsiObject = jsi::Object::createFromHostObject( - runtime, audioBufferHostObject); - jsiObject.setExternalMemoryPressure( - runtime, audioBufferHostObject->getSizeInBytes()); - return jsiObject; - }); - }).detach(); + auto promise = promiseVendor_->createAsyncPromise( + [data, size, sampleRate]() -> PromiseResolver { + auto result = + AudioDecoder::decodeWithMemoryBlock(data, size, sampleRate); + + if (!result) { + return [](jsi::Runtime &runtime) + -> std::variant { + return std::string("Failed to decode audio data."); + }; + } + + auto audioBufferHostObject = + std::make_shared(result); + + return [audioBufferHostObject = + std::move(audioBufferHostObject)](jsi::Runtime &runtime) + -> std::variant { + auto jsiObject = + jsi::Object::createFromHostObject(runtime, audioBufferHostObject); + jsiObject.setExternalMemoryPressure( + runtime, audioBufferHostObject->getSizeInBytes()); + return jsiObject; + }; }); return promise; } @@ -63,29 +64,29 @@ JSI_HOST_FUNCTION_IMPL(AudioDecoderHostObject, decodeWithFilePath) { auto sourcePath = args[0].getString(runtime).utf8(runtime); auto sampleRate = args[1].getNumber(); - auto promise = promiseVendor_->createPromise( - [sourcePath, sampleRate](std::shared_ptr promise) { - std::thread([sourcePath, sampleRate, promise = std::move(promise)]() { - auto result = - AudioDecoder::decodeWithFilePath(sourcePath, sampleRate); - - if (!result) { - promise->reject("Failed to decode audio data source."); - return; - } - - auto audioBufferHostObject = - std::make_shared(result); - - promise->resolve([audioBufferHostObject = std::move( - audioBufferHostObject)](jsi::Runtime &runtime) { - auto jsiObject = jsi::Object::createFromHostObject( - runtime, audioBufferHostObject); - jsiObject.setExternalMemoryPressure( - runtime, audioBufferHostObject->getSizeInBytes()); - return jsiObject; - }); - }).detach(); + auto promise = promiseVendor_->createAsyncPromise( + [sourcePath, sampleRate]() -> PromiseResolver { + auto result = AudioDecoder::decodeWithFilePath(sourcePath, sampleRate); + + if (!result) { + return [](jsi::Runtime &runtime) + -> std::variant { + return std::string("Failed to decode audio data source."); + }; + } + + auto audioBufferHostObject = + std::make_shared(result); + + return [audioBufferHostObject = + std::move(audioBufferHostObject)](jsi::Runtime &runtime) + -> std::variant { + auto jsiObject = + jsi::Object::createFromHostObject(runtime, audioBufferHostObject); + jsiObject.setExternalMemoryPressure( + runtime, audioBufferHostObject->getSizeInBytes()); + return jsiObject; + }; }); return promise; @@ -97,34 +98,31 @@ JSI_HOST_FUNCTION_IMPL(AudioDecoderHostObject, decodeWithPCMInBase64) { auto inputChannelCount = args[2].getNumber(); auto interleaved = args[3].getBool(); - auto promise = promiseVendor_->createPromise( - [b64, inputSampleRate, inputChannelCount, interleaved]( - std::shared_ptr promise) { - std::thread([b64, - inputSampleRate, - inputChannelCount, - interleaved, - promise = std::move(promise)]() { - auto result = AudioDecoder::decodeWithPCMInBase64( - b64, inputSampleRate, inputChannelCount, interleaved); - - if (!result) { - promise->reject("Failed to decode audio data source."); - return; - } - - auto audioBufferHostObject = - std::make_shared(result); - - promise->resolve([audioBufferHostObject = std::move( - audioBufferHostObject)](jsi::Runtime &runtime) { - auto jsiObject = jsi::Object::createFromHostObject( - runtime, audioBufferHostObject); - jsiObject.setExternalMemoryPressure( - runtime, audioBufferHostObject->getSizeInBytes()); - return jsiObject; - }); - }).detach(); + auto promise = promiseVendor_->createAsyncPromise( + [b64, inputSampleRate, inputChannelCount, interleaved]() + -> PromiseResolver { + auto result = AudioDecoder::decodeWithPCMInBase64( + b64, inputSampleRate, inputChannelCount, interleaved); + + if (!result) { + return [](jsi::Runtime &runtime) + -> std::variant { + return std::string("Failed to decode audio data source."); + }; + } + + auto audioBufferHostObject = + std::make_shared(result); + + return [audioBufferHostObject = + std::move(audioBufferHostObject)](jsi::Runtime &runtime) + -> std::variant { + auto jsiObject = + jsi::Object::createFromHostObject(runtime, audioBufferHostObject); + jsiObject.setExternalMemoryPressure( + runtime, audioBufferHostObject->getSizeInBytes()); + return jsiObject; + }; }); return promise; diff --git a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp index 6fa25a498..6d78f8df4 100644 --- a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp +++ b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp @@ -62,8 +62,7 @@ jsi::Value PromiseVendor::createPromise( } jsi::Value PromiseVendor::createAsyncPromise( - std::function(jsi::Runtime &)> - &&function) { + std::function &&function) { auto &runtime = *runtime_; auto callInvoker = callInvoker_; auto threadPool = threadPool_; @@ -74,32 +73,18 @@ jsi::Value PromiseVendor::createAsyncPromise( jsi::Runtime &runtime, const jsi::Value &thisValue, const jsi::Value *arguments, - size_t count) -> jsi::Value { + size_t count) mutable -> jsi::Value { auto resolveLocal = arguments[0].asObject(runtime).asFunction(runtime); auto resolve = std::make_shared(std::move(resolveLocal)); auto rejectLocal = arguments[1].asObject(runtime).asFunction(runtime); auto reject = std::make_shared(std::move(rejectLocal)); - threadPool->schedule([callInvoker = std::move(callInvoker), - function = std::move(function), - resolve = std::move(resolve), - reject = std::move(reject), - &runtime]() { - auto result = function(runtime); - if (std::holds_alternative(result)) { - auto valueShared = std::make_shared( - std::move(std::get(result))); - callInvoker->invokeAsync([resolve, &runtime, valueShared]() -> void { - resolve->call(runtime, *valueShared); - }); - } else { - auto errorMessage = std::get(result); - callInvoker->invokeAsync([reject, &runtime, errorMessage]() -> void { - auto error = jsi::JSError(runtime, errorMessage); - reject->call(runtime, error.value()); - }); - } - }); + threadPool->schedule( + &PromiseVendor::asyncPromiseJob, + std::move(callInvoker), + std::move(function), + std::move(resolve), + std::move(reject)); return jsi::Value::undefined(); }; auto promiseFunction = jsi::Function::createFromHostFunction( @@ -110,4 +95,27 @@ jsi::Value PromiseVendor::createAsyncPromise( return promiseCtor.callAsConstructor(runtime, std::move(promiseFunction)); } +void PromiseVendor::asyncPromiseJob( + std::shared_ptr callInvoker, + std::function &&function, + std::shared_ptr &&resolve, + std::shared_ptr &&reject) { + auto resolver = function(); + callInvoker->invokeAsync( + [resolver = std::move(resolver), + reject = std::move(reject), + resolve = std::move(resolve)](jsi::Runtime &runtime) -> void { + auto result = resolver(runtime); + if (std::holds_alternative(result)) { + auto valueShared = std::make_shared( + std::move(std::get(result))); + resolve->call(runtime, *valueShared); + } else { + auto errorMessage = std::get(result); + auto error = jsi::JSError(runtime, errorMessage); + reject->call(runtime, error.value()); + } + }); +} + } // namespace audioapi diff --git a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h index 985bc64dd..82e4a6029 100644 --- a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h +++ b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h @@ -33,6 +33,8 @@ class Promise { std::function reject_; }; +using PromiseResolver = std::function(jsi::Runtime&)>; + class PromiseVendor { public: PromiseVendor(jsi::Runtime *runtime, const std::shared_ptr &callInvoker): @@ -51,20 +53,30 @@ class PromiseVendor { /// @example /// ```cpp /// auto promise = promiseVendor_->createAsyncPromise( - /// [](jsi::Runtime& rt) -> std::variant { - /// // Simulate some heavy work + /// []() -> PromiseResolver { + /// // Simulate some heavy work on a background thread /// std::this_thread::sleep_for(std::chrono::seconds(2)); - /// return jsi::String::createFromUtf8(rt, "Promise resolved successfully!"); + /// return [](jsi::Runtime &rt) -> std::variant { + /// // Prepare and return the result on javascript thread + /// return jsi::String::createFromUtf8(rt, "Promise resolved successfully!"); + /// }; /// } /// ); /// /// return promise; - jsi::Value createAsyncPromise(std::function(jsi::Runtime&)> &&function); + jsi::Value createAsyncPromise(std::function &&function); private: jsi::Runtime *runtime_; std::shared_ptr callInvoker_; std::shared_ptr threadPool_; + + static void asyncPromiseJob( + std::shared_ptr callInvoker, + std::function &&function, + std::shared_ptr &&resolve, + std::shared_ptr &&reject + ); }; } // namespace audioapi diff --git a/packages/react-native-audio-api/common/cpp/audioapi/utils/MoveOnlyFunction.hpp b/packages/react-native-audio-api/common/cpp/audioapi/utils/MoveOnlyFunction.hpp new file mode 100644 index 000000000..571afbaee --- /dev/null +++ b/packages/react-native-audio-api/common/cpp/audioapi/utils/MoveOnlyFunction.hpp @@ -0,0 +1,91 @@ +#pragma once +#include + +namespace audioapi { + +/// @brief A forward declaration of a move-only function wrapper. +/// @note it is somehow required to have specialization below +/// @tparam Signature +template +class move_only_function; // Forward declaration + +/// @brief A move-only function wrapper similar to std::function but non-copyable. +/// @details This class allows you to store and invoke callable objects (like lambdas, function pointers, or functors) +/// that can be moved but not copied. It is useful for scenarios where you want to ensure that the callable +/// is unique and cannot be duplicated. +/// @note This implementation uses type erasure to store the callable object. +/// @note The callable object must be invocable with the specified arguments and return type. +/// @note IMPORTANT: This thing is implemented in C++23 standard and can be replaced with std::move_only_function once we switch to C++23. +/// @tparam R +/// @tparam ...Args +template +class move_only_function { + /// @brief The base class for type erasure. + /// @note It gets optimized by Empty Base Optimization (EBO) when possible. + struct callable_base { + virtual ~callable_base() = default; + virtual R operator()(Args... args) = 0; + }; + + /// @brief The implementation of the callable object. + /// @tparam F + template + struct callable_impl : callable_base { + /// @brief The stored callable object. + F f; + + /// @brief Construct a new callable_impl object. + /// @tparam G + /// @param func + /// @note The enable_if_t ensures that F can be constructed from G&&. + template>> + callable_impl(G&& func) : f(std::forward(func)) {} + + /// @brief Invoke the stored callable object with the given arguments. + /// @param args + /// @return R + /// @note The if constexpr is used to handle the case when R is void. + inline R operator()(Args... args) override { + if constexpr (std::is_void_v) { + /// To avoid "warning: expression result unused" when R is void + f(std::forward(args)...); + } else { + return f(std::forward(args)...); + } + } + }; + + /// @brief The unique pointer to the base callable type. + std::unique_ptr impl_; + +public: + move_only_function() = default; + move_only_function(std::nullptr_t) noexcept : impl_(nullptr) {} + + template + move_only_function(F&& f) + : impl_(std::make_unique>>(std::forward(f))) {} + + move_only_function(const move_only_function&) = delete; + move_only_function& operator=(const move_only_function&) = delete; + + move_only_function(move_only_function&&) = default; + move_only_function& operator=(move_only_function&&) = default; + + inline explicit operator bool() const noexcept { + return impl_ != nullptr; + } + + inline R operator()(Args... args) { + /// We are unlikely to hit this case as we want to optimize for the common case. + if (impl_ == nullptr) [[ unlikely ]] { + throw std::bad_function_call{}; + } + return (*impl_)(std::forward(args)...); + } + + void swap(move_only_function& other) noexcept { + impl_.swap(other.impl_); + } +}; +} // namespace audioapi diff --git a/packages/react-native-audio-api/common/cpp/audioapi/utils/ThreadPool.hpp b/packages/react-native-audio-api/common/cpp/audioapi/utils/ThreadPool.hpp index 5fec2ff19..811d12392 100644 --- a/packages/react-native-audio-api/common/cpp/audioapi/utils/ThreadPool.hpp +++ b/packages/react-native-audio-api/common/cpp/audioapi/utils/ThreadPool.hpp @@ -3,6 +3,8 @@ #include #include #include + +#include #include namespace audioapi { @@ -15,7 +17,7 @@ namespace audioapi { /// @note IMPORTANT: ThreadPool is not thread-safe and events should be scheduled from a single thread only. class ThreadPool { struct StopEvent {}; - struct TaskEvent { std::function task; }; + struct TaskEvent { audioapi::move_only_function task; }; using Event = std::variant; using Sender = channels::spsc::Sender; @@ -46,14 +48,21 @@ class ThreadPool { } /// @brief Schedule a task to be executed by the thread pool - /// @param task The task to be executed + /// @tparam Func The type of the task function + /// @tparam Args The types of the task function arguments + /// @param task The task function to be executed + /// @param args The arguments to be passed to the task function /// @note This function is lock-free and most of the time wait-free, but may block if the load balancer queue is full. - /// @note Please remember that the task will be executed in a different thread, so make sure to capture any required variables by value. + /// @note Please remember that the task will be executed in a different thread, so make sure to pass any required variables by value or with std::move. /// @note The task should not throw exceptions, as they will not be caught. /// @note The task should end at some point, otherwise the thread pool will never be able to shut down. /// @note IMPORTANT: This function is not thread-safe and should be called from a single thread only. - void schedule(std::function &&task) noexcept { - loadBalancerSender.send(TaskEvent{std::move(task)}); + template>> + void schedule(Func &&task, Args &&... args) noexcept { + auto boundTask = [f = std::forward(task), ...capturedArgs = std::forward(args)]() mutable { + f(std::forward(capturedArgs)...); + }; + loadBalancerSender.send(TaskEvent{audioapi::move_only_function(std::move(boundTask))}); } private: