LCOV - code coverage report
Current view: top level - capy - when_all.hpp (source / functions) Coverage Total Hit
Test: coverage_remapped.info Lines: 98.0 % 98 96
Test Date: 2026-02-10 18:54:58 Functions: 88.2 % 560 494

            Line data    Source code
       1              : //
       2              : // Copyright (c) 2026 Steve Gerbino
       3              : //
       4              : // Distributed under the Boost Software License, Version 1.0. (See accompanying
       5              : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
       6              : //
       7              : // Official repository: https://github.com/cppalliance/capy
       8              : //
       9              : 
      10              : #ifndef BOOST_CAPY_WHEN_ALL_HPP
      11              : #define BOOST_CAPY_WHEN_ALL_HPP
      12              : 
      13              : #include <boost/capy/detail/config.hpp>
      14              : #include <boost/capy/concept/executor.hpp>
      15              : #include <boost/capy/concept/io_awaitable.hpp>
      16              : #include <boost/capy/coro.hpp>
      17              : #include <boost/capy/ex/executor_ref.hpp>
      18              : #include <boost/capy/ex/frame_allocator.hpp>
      19              : #include <boost/capy/task.hpp>
      20              : 
      21              : #include <array>
      22              : #include <atomic>
      23              : #include <exception>
      24              : #include <optional>
      25              : #include <stop_token>
      26              : #include <tuple>
      27              : #include <type_traits>
      28              : #include <utility>
      29              : 
      30              : namespace boost {
      31              : namespace capy {
      32              : 
      33              : namespace detail {
      34              : 
      35              : /** Type traits that drop void entries from a list of result types.
      36              : 
      37              :     Void-returning tasks do not contribute a value to the result tuple:
      38              :     wrap_non_void_t maps void to tuple<> and T to tuple<T>; filter_void_tuple_t concatenates them.
      39              : 
      40              :     Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
      41              : */
      42              : template<typename T>
      43              : using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;
      44              : 
      45              : template<typename... Ts>
      46              : using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
      47              : 
      48              : /** Holds the result of a single non-void task within when_all.
      49              : */
      50              : template<typename T>
      51              : struct result_holder
      52              : {
      53              :     std::optional<T> value_; // empty until set() stores a value
      54              : 
      55           59 :     void set(T v)
      56              :     {
      57           59 :         value_ = std::move(v);
      58           59 :     }
      59              : 
      60           52 :     T get() &&
      61              :     {
      62           52 :         return std::move(*value_); // dereferences without checking has_value()
      63              :     }
      64              : };
      65              : 
      66              : /** Specialization for void tasks - no value storage and no set()/get() members.
      67              : */
      68              : template<>
      69              : struct result_holder<void>
      70              : {
      71              : };
      72              : 
      73              : /** Shared state for when_all operation.
      74              : 
      75              :     @tparam Ts The result types of the tasks.
      76              : */
      77              : template<typename... Ts>
      78              : struct when_all_state
      79              : {
      80              :     static constexpr std::size_t task_count = sizeof...(Ts);
      81              : 
      82              :     // Completion tracking - starts at task_count; each runner decrements it once in final_suspend
      83              :     std::atomic<std::size_t> remaining_count_;
      84              : 
      85              :     // Result storage in input order
      86              :     std::tuple<result_holder<Ts>...> results_;
      87              : 
      88              :     // Runner handles recorded at launch. Runners destroy themselves in final_suspend, so entries may dangle.
      89              :     std::array<coro, task_count> runner_handles_{};
      90              : 
      91              :     // Exception storage - first error wins, others discarded
      92              :     std::atomic<bool> has_exception_{false};
      93              :     std::exception_ptr first_exception_;
      94              : 
      95              :     // Stop propagation - on error, request stop for siblings
      96              :     std::stop_source stop_source_;
      97              : 
      98              :     // Connects parent's stop_token to our stop_source
      99              :     struct stop_callback_fn
     100              :     {
     101              :         std::stop_source* source_;
     102            4 :         void operator()() const { source_->request_stop(); }
     103              :     };
     104              :     using stop_callback_t = std::stop_callback<stop_callback_fn>;
     105              :     std::optional<stop_callback_t> parent_stop_callback_;
     106              : 
     107              :     // Parent resumption - the continuation is dispatched on caller_ex_ by the last runner to finish
     108              :     coro continuation_;
     109              :     executor_ref caller_ex_;
     110              : 
     111           33 :     when_all_state()
     112           33 :         : remaining_count_(task_count)
     113              :     {
     114           33 :     }
     115              : 
     116              :     // Runners self-destruct in final_suspend. No destruction needed here.
     117              : 
     118              :     /** Capture an exception (first one wins; later ones are dropped).
     119              :     */
     120           11 :     void capture_exception(std::exception_ptr ep)
     121              :     {
     122           11 :         bool expected = false;
     123           11 :         if(has_exception_.compare_exchange_strong(
     124              :             expected, true, std::memory_order_relaxed))
     125            8 :             first_exception_ = ep; // only the caller that wins the exchange writes
     126           11 :     }
     127              : 
     128              : };
     129              : 
     130              : /** Wrapper coroutine that intercepts task completion.
     131              : 
     132              :     This runner awaits its assigned task and stores the result in the shared state,
     133              :     or captures the exception and requests stop; it destroys itself in final_suspend.
     134              : */
     135              : template<typename T, typename... Ts>
     136              : struct when_all_runner
     137              : {
     138              :     struct promise_type // : frame_allocating_base  // DISABLED FOR TESTING
     139              :     {
     140              :         when_all_state<Ts...>* state_ = nullptr;
     141              :         executor_ref ex_;
     142              :         std::stop_token stop_token_;
     143              : 
     144           78 :         when_all_runner get_return_object()
     145              :         {
     146           78 :             return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
     147              :         }
     148              : 
     149           78 :         std::suspend_always initial_suspend() noexcept
     150              :         {
     151           78 :             return {};
     152              :         }
     153              : 
     154           78 :         auto final_suspend() noexcept
     155              :         {
     156              :             struct awaiter
     157              :             {
     158              :                 promise_type* p_;
     159              : 
     160           78 :                 bool await_ready() const noexcept
     161              :                 {
     162           78 :                     return false;
     163              :                 }
     164              : 
     165           78 :                 void await_suspend(coro h) noexcept
     166              :                 {
     167              :                     // Extract everything needed for signaling before
     168              :                     // self-destruction. Inline dispatch may destroy
     169              :                     // when_all_state, so we can't access members after.
     170           78 :                     auto* state = p_->state_;
     171           78 :                     auto* counter = &state->remaining_count_;
     172           78 :                     auto caller_ex = state->caller_ex_;
     173           78 :                     auto cont = state->continuation_;
     174              : 
     175              :                     // Self-destruct first - state no longer destroys runners
     176           78 :                     h.destroy();
     177              : 
     178              :                     // Signal completion. If last, dispatch parent.
     179              :                     // Uses only local copies - safe even if state
     180              :                     // is destroyed during inline dispatch.
     181           78 :                     auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
     182           78 :                     if(remaining == 1)
     183           33 :                         caller_ex.dispatch(cont);
     184           78 :                 }
     185              : 
     186            0 :                 void await_resume() const noexcept // not reached: await_suspend destroys the frame
     187              :                 {
     188            0 :                 }
     189              :             };
     190           78 :             return awaiter{this};
     191              :         }
     192              : 
     193           67 :         void return_void()
     194              :         {
     195           67 :         }
     196              : 
     197           11 :         void unhandled_exception()
     198              :         {
     199           11 :             state_->capture_exception(std::current_exception());
     200              :             // Request stop for sibling tasks
     201           11 :             state_->stop_source_.request_stop();
     202           11 :         }
     203              : 
     204              :         template<class Awaitable>
     205              :         struct transform_awaiter
     206              :         {
     207              :             std::decay_t<Awaitable> a_;
     208              :             promise_type* p_;
     209              : 
     210           78 :             bool await_ready()
     211              :             {
     212           78 :                 return a_.await_ready();
     213              :             }
     214              : 
     215           78 :             decltype(auto) await_resume()
     216              :             {
     217           78 :                 return a_.await_resume();
     218              :             }
     219              : 
     220              :             template<class Promise>
     221           77 :             auto await_suspend(std::coroutine_handle<Promise> h)
     222              :             {
     223           77 :                 return a_.await_suspend(h, p_->ex_, p_->stop_token_); // pass the runner's executor and stop token to the child
     224              :             }
     225              :         };
     226              : 
     227              :         template<class Awaitable>
     228           78 :         auto await_transform(Awaitable&& a)
     229              :         {
     230              :             using A = std::decay_t<Awaitable>;
     231              :             if constexpr (IoAwaitable<A>)
     232              :             {
     233              :                 return transform_awaiter<Awaitable>{
     234          156 :                     std::forward<Awaitable>(a), this};
     235              :             }
     236              :             else
     237              :             {
     238              :                 static_assert(sizeof(A) == 0, "requires IoAwaitable");
     239              :             }
     240           78 :         }
     241              :     };
     242              : 
     243              :     std::coroutine_handle<promise_type> h_;
     244              : 
     245           78 :     explicit when_all_runner(std::coroutine_handle<promise_type> h)
     246           78 :         : h_(h)
     247              :     {
     248           78 :     }
     249              : 
     250              :     // Move constructor takes the handle and nulls the source; kept because some clang versions need it
     251              :     when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
     252              : 
     253              :     // Non-copyable
     254              :     when_all_runner(when_all_runner const&) = delete;
     255              :     when_all_runner& operator=(when_all_runner const&) = delete;
     256              :     when_all_runner& operator=(when_all_runner&&) = delete;
     257              : 
     258           78 :     auto release() noexcept // hands the handle to the caller and leaves h_ null
     259              :     {
     260           78 :         return std::exchange(h_, nullptr);
     261              :     }
     262              : };
     263              : 
     264              : /** Create a runner coroutine for a single awaitable.
     265              : 
     266              :     Awaitable is passed by value to ensure proper coroutine frame storage; a non-void
     267              :     result of awaiting it is stored via std::get<Index>(state->results_).set(...). */
     268              : template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
     269              : when_all_runner<awaitable_result_t<Awaitable>, Ts...>
     270           78 : make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
     271              : {
     272              :     using T = awaitable_result_t<Awaitable>;
     273              :     if constexpr (std::is_void_v<T>)
     274              :     {
     275              :         co_await std::move(inner);
     276              :     }
     277              :     else
     278              :     {
     279              :         std::get<Index>(state->results_).set(co_await std::move(inner));
     280              :     }
     281          156 : }
     282              : 
     283              : /** Internal awaitable that launches all runner coroutines and waits.
     284              : 
     285              :     This awaitable is used inside the when_all coroutine to handle
     286              :     the concurrent execution of child awaitables.
     287              : */
     288              : template<IoAwaitable... Awaitables>
     289              : class when_all_launcher
     290              : {
     291              :     using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
     292              : 
     293              :     std::tuple<Awaitables...>* awaitables_;
     294              :     state_type* state_;
     295              : 
     296              : public:
     297           33 :     when_all_launcher(
     298              :         std::tuple<Awaitables...>* awaitables,
     299              :         state_type* state)
     300           33 :         : awaitables_(awaitables)
     301           33 :         , state_(state)
     302              :     {
     303           33 :     }
     304              : 
     305           33 :     bool await_ready() const noexcept
     306              :     {
     307           33 :         return sizeof...(Awaitables) == 0; // nothing to launch, so do not suspend
     308              :     }
     309              : 
     310           33 :     coro await_suspend(coro continuation, executor_ref const& caller_ex, std::stop_token const& parent_token = {})
     311              :     {
     312           33 :         state_->continuation_ = continuation;
     313           33 :         state_->caller_ex_ = caller_ex;
     314              : 
     315              :         // Forward parent's stop requests to children
     316           33 :         if(parent_token.stop_possible())
     317              :         {
     318           16 :             state_->parent_stop_callback_.emplace(
     319              :                 parent_token,
     320            8 :                 typename state_type::stop_callback_fn{&state_->stop_source_});
     321              : 
     322            8 :             if(parent_token.stop_requested())
     323            4 :                 state_->stop_source_.request_stop();
     324              :         }
     325              : 
     326              :         // CRITICAL: If the last task finishes synchronously then the parent
     327              :         // coroutine resumes, destroying its frame, and destroying this object
     328              :         // prior to the completion of await_suspend. Therefore, await_suspend
     329              :         // must ensure `this` cannot be referenced after calling `launch_one`
     330              :         // for the last time.
     331           33 :         auto token = state_->stop_source_.get_token();
     332           58 :         [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     333           33 :             (..., launch_one<Is>(caller_ex, token));
     334           33 :         }(std::index_sequence_for<Awaitables...>{});
     335              : 
     336              :         // The last runner to finish dispatches the continuation from its final_suspend awaiter
     337           66 :         return std::noop_coroutine();
     338           33 :     }
     339              : 
     340           33 :     void await_resume() const noexcept
     341              :     {
     342              :         // Results are extracted by the when_all coroutine from state
     343           33 :     }
     344              : 
     345              : private:
     346              :     template<std::size_t I>
     347           78 :     void launch_one(executor_ref caller_ex, std::stop_token token)
     348              :     {
     349           78 :         auto runner = make_when_all_runner<I>(
     350           78 :             std::move(std::get<I>(*awaitables_)), state_);
     351              : 
     352           78 :         auto h = runner.release();
     353           78 :         h.promise().state_ = state_;
     354           78 :         h.promise().ex_ = caller_ex;
     355           78 :         h.promise().stop_token_ = token;
     356              : 
     357           78 :         coro ch{h};
     358           78 :         state_->runner_handles_[I] = ch;
     359           78 :         state_->caller_ex_.dispatch(ch); // last statement: nothing may touch `this` after the final dispatch
     360           78 :     }
     361              : };
     362              : 
     363              : /** Compute the result type for when_all.
     364              : 
     365              :     Yields void when every type in Ts is void (P2300 aligned),
     366              :     otherwise yields filter_void_tuple_t<Ts...>, a tuple with void types filtered out.
     367              : */
     368              : template<typename... Ts>
     369              : using when_all_result_t = std::conditional_t<
     370              :     std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
     371              :     void,
     372              :     filter_void_tuple_t<Ts...>>;
     373              : 
     374              : /** Extract result I as a one-element tuple (moved out of state), or std::tuple<>() for void.
     375              :     This is a separate function to work around a GCC-11 ICE that occurs
     376              :     when using nested immediately-invoked lambdas with pack expansion.
     377              : */
     378              : template<std::size_t I, typename... Ts>
     379           55 : auto extract_single_result(when_all_state<Ts...>& state)
     380              : {
     381              :     using T = std::tuple_element_t<I, std::tuple<Ts...>>;
     382              :     if constexpr (std::is_void_v<T>)
     383            3 :         return std::tuple<>();
     384              :     else
     385           52 :         return std::make_tuple(std::move(std::get<I>(state.results_)).get());
     386              : }
     387              : 
     388              : /** Concatenate the per-index results from state into one tuple, in input order, skipping void entries.
     389              : */
     390              : template<typename... Ts>
     391           23 : auto extract_results(when_all_state<Ts...>& state)
     392              : {
     393           42 :     return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     394           24 :         return std::tuple_cat(extract_single_result<Is>(state)...);
     395           46 :     }(std::index_sequence_for<Ts...>{});
     396              : }
     397              : 
     398              : } // namespace detail
     399              : 
     400              : /** Execute multiple awaitables concurrently and collect their results.
     401              : 
     402              :     Launches all awaitables simultaneously and waits for all to complete
     403              :     before returning. Results are collected in input order. If any
     404              :     awaitable throws, cancellation is requested for siblings and the first
     405              :     exception is rethrown after all awaitables complete.
     406              : 
     407              :     @li All child awaitables run concurrently on the caller's executor
     408              :     @li Results are returned as a tuple in input order
     409              :     @li Void-returning awaitables do not contribute to the result tuple
     410              :     @li If all awaitables return void, `when_all` returns `task<void>`
     411              :     @li First exception wins; subsequent exceptions are discarded
     412              :     @li Stop is requested for siblings on first error
     413              :     @li Completes only after all children have finished
     414              : 
     415              :     @par Thread Safety
     416              :     The returned task must be awaited from a single execution context.
     417              :     Child awaitables execute concurrently but complete through the caller's
     418              :     executor.
     419              : 
     420              :     @param awaitables The awaitables to execute concurrently. Each must
     421              :         satisfy @ref IoAwaitable and is consumed (moved-from) when
     422              :         `when_all` is awaited.
     423              : 
     424              :     @return A task yielding a tuple of non-void results. Returns
     425              :         `task<void>` when all input awaitables return void.
     426              : 
     427              :     @par Example
     428              : 
     429              :     @code
     430              :     task<> example()
     431              :     {
     432              :         // Concurrent fetch, results collected in order
     433              :         auto [user, posts] = co_await when_all(
     434              :             fetch_user( id ),      // task<User>
     435              :             fetch_posts( id )      // task<std::vector<Post>>
     436              :         );
     437              : 
     438              :         // Void awaitables don't contribute to result
     439              :         co_await when_all(
     440              :             log_event( "start" ),  // task<void>
     441              :             notify_user( id )      // task<void>
     442              :         );
     443              :         // Returns task<void>, no result tuple
     444              :     }
     445              :     @endcode
     446              : 
     447              :     @see IoAwaitable, task
     448              : */
     449              : template<IoAwaitable... As>
     450           33 : [[nodiscard]] auto when_all(As... awaitables)
     451              :     -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
     452              : {
     453              :     using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
     454              : 
     455              :     // State is stored in the coroutine frame, using the frame allocator
     456              :     detail::when_all_state<detail::awaitable_result_t<As>...> state;
     457              : 
     458              :     // Store awaitables in the frame
     459              :     std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
     460              : 
     461              :     // Launch all awaitables and wait for completion
     462              :     co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
     463              : 
     464              :     // Propagate first exception if any.
     465              :     // Safe without explicit acquire: capture_exception() is sequenced-before
     466              :     // the runner's acq_rel fetch_sub in final_suspend, which synchronizes-with
     467              :     // the last runner's decrement that resumes this coroutine.
     468              :     if(state.first_exception_)
     469              :         std::rethrow_exception(state.first_exception_);
     470              : 
     471              :     // Extract and return results
     472              :     if constexpr (std::is_void_v<result_type>)
     473              :         co_return;
     474              :     else
     475              :         co_return detail::extract_results(state);
     476           66 : }
     477              : 
     478              : /// Public alias for `detail::when_all_result_t`: the result type of `when_all` for the given task result types.
     479              : template<typename... Ts>
     480              : using when_all_result_type = detail::when_all_result_t<Ts...>;
     481              : 
     482              : } // namespace capy
     483              : } // namespace boost
     484              : 
     485              : #endif
        

Generated by: LCOV version 2.3