libs/capy/include/boost/capy/when_all.hpp

96.9% Lines (95/98) 90.5% Functions (431/476) 100.0% Branches (22/22)
libs/capy/include/boost/capy/when_all.hpp
Line Branch Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_awaitable.hpp>
16 #include <boost/capy/coro.hpp>
17 #include <boost/capy/ex/executor_ref.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Map a single result type to a zero- or one-element tuple.

    Yields `std::tuple<>` when `T` is (possibly cv-qualified) void and
    `std::tuple<T>` otherwise. Building block for `filter_void_tuple_t`.
*/
template<typename T>
using wrap_non_void_t = typename std::conditional<
    std::is_void<T>::value,
    std::tuple<>,
    std::tuple<T>>::type;

/** Type trait to filter void types from a tuple.

    Void-returning tasks do not contribute a value to the result tuple.
    The filtered type is computed by concatenating the per-type wrappers
    in an unevaluated context, so no `Ts` needs to be constructible.

    Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Holds the result of a single task within when_all.

    The slot starts empty. It is filled by `set` and drained by `get`.

    @tparam T The non-void result type of the task.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /** Store the task's result.

        The value is constructed in place rather than assigned, so `T`
        only needs to be move-constructible; it does not have to be
        move-assignable.

        @param v The result to store (taken by value and moved in).
    */
    void set(T v)
    {
        value_.emplace(std::move(v));
    }

    /** Move the stored result out of the holder.

        Precondition: `set` has been called. The optional is
        dereferenced without a check.

        @return The stored value, moved out.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};

/** Specialization for void tasks - no value storage needed.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<coro, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 4 void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 coro continuation_;
109 executor_ref caller_ex_;
110
111 33 when_all_state()
112
1/1
✓ Branch 5 taken 33 times.
33 : remaining_count_(task_count)
113 {
114 33 }
115
116 // Runners self-destruct in final_suspend. No destruction needed here.
117
118 /** Capture an exception (first one wins).
119 */
120 11 void capture_exception(std::exception_ptr ep)
121 {
122 11 bool expected = false;
123
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 3 times.
11 if(has_exception_.compare_exchange_strong(
124 expected, true, std::memory_order_relaxed))
125 8 first_exception_ = ep;
126 11 }
127
128 };
129
130 /** Wrapper coroutine that intercepts task completion.
131
132 This runner awaits its assigned task and stores the result in
133 the shared state, or captures the exception and requests stop.
134 */
135 template<typename T, typename... Ts>
136 struct when_all_runner
137 {
138 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
139 {
140 when_all_state<Ts...>* state_ = nullptr;
141 executor_ref ex_;
142 std::stop_token stop_token_;
143
144 78 when_all_runner get_return_object()
145 {
146 78 return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
147 }
148
149 78 std::suspend_always initial_suspend() noexcept
150 {
151 78 return {};
152 }
153
154 78 auto final_suspend() noexcept
155 {
156 struct awaiter
157 {
158 promise_type* p_;
159
160 8 bool await_ready() const noexcept
161 {
162 8 return false;
163 }
164
165 8 void await_suspend(coro h) noexcept
166 {
167 // Extract everything needed for signaling before
168 // self-destruction. Inline dispatch may destroy
169 // when_all_state, so we can't access members after.
170 8 auto* state = p_->state_;
171 8 auto* counter = &state->remaining_count_;
172 8 auto caller_ex = state->caller_ex_;
173 8 auto cont = state->continuation_;
174
175 // Self-destruct first - state no longer destroys runners
176 8 h.destroy();
177
178 // Signal completion. If last, dispatch parent.
179 // Uses only local copies - safe even if state
180 // is destroyed during inline dispatch.
181 8 auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
182
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
8 if(remaining == 1)
183 4 caller_ex.dispatch(cont);
184 8 }
185
186 void await_resume() const noexcept
187 {
188 }
189 };
190 78 return awaiter{this};
191 }
192
193 67 void return_void()
194 {
195 67 }
196
197 11 void unhandled_exception()
198 {
199 11 state_->capture_exception(std::current_exception());
200 // Request stop for sibling tasks
201 11 state_->stop_source_.request_stop();
202 11 }
203
204 template<class Awaitable>
205 struct transform_awaiter
206 {
207 std::decay_t<Awaitable> a_;
208 promise_type* p_;
209
210 78 bool await_ready()
211 {
212 78 return a_.await_ready();
213 }
214
215 78 decltype(auto) await_resume()
216 {
217 78 return a_.await_resume();
218 }
219
220 template<class Promise>
221 77 auto await_suspend(std::coroutine_handle<Promise> h)
222 {
223 77 return a_.await_suspend(h, p_->ex_, p_->stop_token_);
224 }
225 };
226
227 template<class Awaitable>
228 78 auto await_transform(Awaitable&& a)
229 {
230 using A = std::decay_t<Awaitable>;
231 if constexpr (IoAwaitable<A>)
232 {
233 return transform_awaiter<Awaitable>{
234 156 std::forward<Awaitable>(a), this};
235 }
236 else
237 {
238 static_assert(sizeof(A) == 0, "requires IoAwaitable");
239 }
240 78 }
241 };
242
243 std::coroutine_handle<promise_type> h_;
244
245 78 explicit when_all_runner(std::coroutine_handle<promise_type> h)
246 78 : h_(h)
247 {
248 78 }
249
250 // Enable move for all clang versions - some versions need it
251 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
252
253 // Non-copyable
254 when_all_runner(when_all_runner const&) = delete;
255 when_all_runner& operator=(when_all_runner const&) = delete;
256 when_all_runner& operator=(when_all_runner&&) = delete;
257
258 78 auto release() noexcept
259 {
260 78 return std::exchange(h_, nullptr);
261 }
262 };
263
264 /** Create a runner coroutine for a single awaitable.
265
266 Awaitable is passed directly to ensure proper coroutine frame storage.
267 */
268 template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
269 when_all_runner<awaitable_result_t<Awaitable>, Ts...>
270
1/1
✓ Branch 1 taken 78 times.
78 make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
271 {
272 using T = awaitable_result_t<Awaitable>;
273 if constexpr (std::is_void_v<T>)
274 {
275 co_await std::move(inner);
276 }
277 else
278 {
279 std::get<Index>(state->results_).set(co_await std::move(inner));
280 }
281 156 }
282
283 /** Internal awaitable that launches all runner coroutines and waits.
284
285 This awaitable is used inside the when_all coroutine to handle
286 the concurrent execution of child awaitables.
287 */
288 template<IoAwaitable... Awaitables>
289 class when_all_launcher
290 {
291 using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
292
293 std::tuple<Awaitables...>* awaitables_;
294 state_type* state_;
295
296 public:
297 33 when_all_launcher(
298 std::tuple<Awaitables...>* awaitables,
299 state_type* state)
300 33 : awaitables_(awaitables)
301 33 , state_(state)
302 {
303 33 }
304
305 33 bool await_ready() const noexcept
306 {
307 33 return sizeof...(Awaitables) == 0;
308 }
309
310 33 coro await_suspend(coro continuation, executor_ref const& caller_ex, std::stop_token const& parent_token = {})
311 {
312 33 state_->continuation_ = continuation;
313 33 state_->caller_ex_ = caller_ex;
314
315 // Forward parent's stop requests to children
316
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 25 times.
33 if(parent_token.stop_possible())
317 {
318 16 state_->parent_stop_callback_.emplace(
319 parent_token,
320 8 typename state_type::stop_callback_fn{&state_->stop_source_});
321
322
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
8 if(parent_token.stop_requested())
323 4 state_->stop_source_.request_stop();
324 }
325
326 // CRITICAL: If the last task finishes synchronously then the parent
327 // coroutine resumes, destroying its frame, and destroying this object
328 // prior to the completion of await_suspend. Therefore, await_suspend
329 // must ensure `this` cannot be referenced after calling `launch_one`
330 // for the last time.
331 33 auto token = state_->stop_source_.get_token();
332 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
333
2/2
✓ Branch 2 taken 4 times.
✓ Branch 6 taken 4 times.
4 (..., launch_one<Is>(caller_ex, token));
334
2/2
✓ Branch 1 taken 29 times.
✓ Branch 1 taken 4 times.
33 }(std::index_sequence_for<Awaitables...>{});
335
336 // Let signal_completion() handle resumption
337 66 return std::noop_coroutine();
338 33 }
339
340 33 void await_resume() const noexcept
341 {
342 // Results are extracted by the when_all coroutine from state
343 33 }
344
345 private:
346 template<std::size_t I>
347 78 void launch_one(executor_ref caller_ex, std::stop_token token)
348 {
349
1/1
✓ Branch 2 taken 78 times.
78 auto runner = make_when_all_runner<I>(
350 78 std::move(std::get<I>(*awaitables_)), state_);
351
352 78 auto h = runner.release();
353 78 h.promise().state_ = state_;
354 78 h.promise().ex_ = caller_ex;
355 78 h.promise().stop_token_ = token;
356
357 78 coro ch{h};
358 78 state_->runner_handles_[I] = ch;
359
1/1
✓ Branch 1 taken 78 times.
78 state_->caller_ex_.dispatch(ch);
360 78 }
361 };
362
363 /** Compute the result type for when_all.
364
365 Returns void when all tasks are void (P2300 aligned),
366 otherwise returns a tuple with void types filtered out.
367 */
368 template<typename... Ts>
369 using when_all_result_t = std::conditional_t<
370 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
371 void,
372 filter_void_tuple_t<Ts...>>;
373
374 /** Helper to extract a single result, returning empty tuple for void.
375 This is a separate function to work around a GCC-11 ICE that occurs
376 when using nested immediately-invoked lambdas with pack expansion.
377 */
378 template<std::size_t I, typename... Ts>
379 55 auto extract_single_result(when_all_state<Ts...>& state)
380 {
381 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
382 if constexpr (std::is_void_v<T>)
383 3 return std::tuple<>();
384 else
385
1/1
✓ Branch 4 taken 52 times.
52 return std::make_tuple(std::move(std::get<I>(state.results_)).get());
386 }
387
388 /** Extract results from state, filtering void types.
389 */
390 template<typename... Ts>
391 23 auto extract_results(when_all_state<Ts...>& state)
392 {
393 23 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
394
3/3
✓ Branch 1 taken 4 times.
✓ Branch 4 taken 4 times.
✓ Branch 7 taken 4 times.
4 return std::tuple_cat(extract_single_result<Is>(state)...);
395
1/1
✓ Branch 1 taken 23 times.
46 }(std::index_sequence_for<Ts...>{});
396 }
397
398 } // namespace detail
399
400 /** Execute multiple awaitables concurrently and collect their results.
401
402 Launches all awaitables simultaneously and waits for all to complete
403 before returning. Results are collected in input order. If any
404 awaitable throws, cancellation is requested for siblings and the first
405 exception is rethrown after all awaitables complete.
406
407 @li All child awaitables run concurrently on the caller's executor
408 @li Results are returned as a tuple in input order
409 @li Void-returning awaitables do not contribute to the result tuple
410 @li If all awaitables return void, `when_all` returns `task<void>`
411 @li First exception wins; subsequent exceptions are discarded
412 @li Stop is requested for siblings on first error
413 @li Completes only after all children have finished
414
415 @par Thread Safety
416 The returned task must be awaited from a single execution context.
417 Child awaitables execute concurrently but complete through the caller's
418 executor.
419
420 @param awaitables The awaitables to execute concurrently. Each must
421 satisfy @ref IoAwaitable and is consumed (moved-from) when
422 `when_all` is awaited.
423
424 @return A task yielding a tuple of non-void results. Returns
425 `task<void>` when all input awaitables return void.
426
427 @par Example
428
429 @code
430 task<> example()
431 {
432 // Concurrent fetch, results collected in order
433 auto [user, posts] = co_await when_all(
434 fetch_user( id ), // task<User>
435 fetch_posts( id ) // task<std::vector<Post>>
436 );
437
438 // Void awaitables don't contribute to result
439 co_await when_all(
440 log_event( "start" ), // task<void>
441 notify_user( id ) // task<void>
442 );
443 // Returns task<void>, no result tuple
444 }
445 @endcode
446
447 @see IoAwaitable, task
448 */
449 template<IoAwaitable... As>
450
1/1
✓ Branch 1 taken 33 times.
33 [[nodiscard]] auto when_all(As... awaitables)
451 -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
452 {
453 using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
454
455 // State is stored in the coroutine frame, using the frame allocator
456 detail::when_all_state<detail::awaitable_result_t<As>...> state;
457
458 // Store awaitables in the frame
459 std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
460
461 // Launch all awaitables and wait for completion
462 co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
463
464 // Propagate first exception if any.
465 // Safe without explicit acquire: capture_exception() is sequenced-before
466 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
467 // last task's decrement that resumes this coroutine.
468 if(state.first_exception_)
469 std::rethrow_exception(state.first_exception_);
470
471 // Extract and return results
472 if constexpr (std::is_void_v<result_type>)
473 co_return;
474 else
475 co_return detail::extract_results(state);
476 66 }
477
478 /// Compute the result type of `when_all` for the given task types.
479 template<typename... Ts>
480 using when_all_result_type = detail::when_all_result_t<Ts...>;
481
482 } // namespace capy
483 } // namespace boost
484
485 #endif
486