From 0c11c32a5a2e52df0fac478733eb7b746aadbbfb Mon Sep 17 00:00:00 2001 From: Eric Niebler Date: Thu, 6 Feb 2025 18:14:47 +0000 Subject: [PATCH] misc bug fixes for cudax ustdex --- .../cuda/experimental/__async/sender/env.cuh | 3 ++ .../cuda/experimental/__async/sender/meta.cuh | 2 + .../__async/sender/rcvr_with_env.cuh | 4 ++ .../experimental/__async/sender/sync_wait.cuh | 47 +++++++++++++------ .../experimental/__async/sender/utility.cuh | 10 ++-- .../experimental/__async/sender/write_env.cuh | 2 +- 6 files changed, 47 insertions(+), 21 deletions(-) diff --git a/cudax/include/cuda/experimental/__async/sender/env.cuh b/cudax/include/cuda/experimental/__async/sender/env.cuh index eb0e232fead..7227563de62 100644 --- a/cudax/include/cuda/experimental/__async/sender/env.cuh +++ b/cudax/include/cuda/experimental/__async/sender/env.cuh @@ -76,6 +76,9 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT prop prop& operator=(const prop&) = delete; }; +template +prop(_Query, _Value) -> prop<_Query, _Value>; + template struct _CCCL_TYPE_VISIBILITY_DEFAULT env { diff --git a/cudax/include/cuda/experimental/__async/sender/meta.cuh b/cudax/include/cuda/experimental/__async/sender/meta.cuh index bff5c71b4fb..7603360b214 100644 --- a/cudax/include/cuda/experimental/__async/sender/meta.cuh +++ b/cudax/include/cuda/experimental/__async/sender/meta.cuh @@ -59,6 +59,8 @@ struct _IN_ALGORITHM; struct _WHAT; +struct _WHY; + struct _WITH_FUNCTION; struct _WITH_SENDER; diff --git a/cudax/include/cuda/experimental/__async/sender/rcvr_with_env.cuh b/cudax/include/cuda/experimental/__async/sender/rcvr_with_env.cuh index b4d6793656f..b5138206e96 100644 --- a/cudax/include/cuda/experimental/__async/sender/rcvr_with_env.cuh +++ b/cudax/include/cuda/experimental/__async/sender/rcvr_with_env.cuh @@ -134,6 +134,10 @@ struct __rcvr_with_env_t<_Rcvr*, _Env> _Rcvr* __rcvr_; _Env __env_; }; + +template +__rcvr_with_env_t(_Rcvr, _Env) -> __rcvr_with_env_t<_Rcvr, _Env>; + } // namespace cuda::experimental::__async #include diff --git a/cudax/include/cuda/experimental/__async/sender/sync_wait.cuh b/cudax/include/cuda/experimental/__async/sender/sync_wait.cuh index ddf23694272..971abcc9433 100644 --- a/cudax/include/cuda/experimental/__async/sender/sync_wait.cuh +++ b/cudax/include/cuda/experimental/__async/sender/sync_wait.cuh @@ -26,6 +26,7 @@ // run_loop isn't supported on-device yet, so neither can sync_wait be. #if !defined(__CUDA_ARCH__) +# include # include # include @@ -33,6 +34,7 @@ # include # include # include +# include # include @@ -115,26 +117,38 @@ private: } }; - using __values_t = value_types_of_t<_Sndr, __rcvr_t, _CUDA_VSTD::tuple, _CUDA_VSTD::__type_self_t>; + using __completions_t = completion_signatures_of_t<_Sndr, __rcvr_t>; + + struct __on_success + { + using type = __value_types<__completions_t, _CUDA_VSTD::tuple, _CUDA_VSTD::__type_self_t>; + }; + + using __on_error = _CUDA_VSTD::type_identity<_CUDA_VSTD::tuple<__completions_t>>; + + using __values_t = + typename _CUDA_VSTD::_If<__is_completion_signatures<__completions_t>, __on_success, __on_error>::type; _CUDA_VSTD::optional<__values_t>* __values_; ::std::exception_ptr __eptr_; run_loop __loop_; }; - struct __invalid_sync_wait + template + struct __always_false : _CUDA_VSTD::false_type + {}; + + template + struct __bad_sync_wait { - const __invalid_sync_wait& value() const - { - return *this; - } + static_assert(__always_false<_Diagnostic>(), + "sync_wait cannot compute the completions of the sender passed to it."); + static __bad_sync_wait __result(); - const __invalid_sync_wait& operator*() const - { - return *this; - } + const __bad_sync_wait& value() const; + const __bad_sync_wait& operator*() const; - int __i_; + int i{}; // so that structured bindings kinda work }; public: @@ -168,12 +182,11 @@ public: { using __rcvr_t = typename __state_t<_Sndr>::__rcvr_t; using __values_t = typename __state_t<_Sndr>::__values_t; - using __completions = completion_signatures_of_t<_Sndr, __rcvr_t>; - static_assert(__is_completion_signatures<__completions>); + using __completions = typename __state_t<_Sndr>::__completions_t; if constexpr (!__is_completion_signatures<__completions>) { - return __invalid_sync_wait{0}; + return __bad_sync_wait<__completions>::__result(); } else { @@ -196,6 +209,12 @@ public: return __result; // uses NRVO to "return" the result } } + + template + auto operator()(_Sndr&& __sndr, _Env&& __env) const + { + return (*this)(__async::write_env(static_cast<_Sndr&&>(__sndr), static_cast<_Env&&>(__env))); + } }; _CCCL_GLOBAL_CONSTANT sync_wait_t sync_wait{}; diff --git a/cudax/include/cuda/experimental/__async/sender/utility.cuh b/cudax/include/cuda/experimental/__async/sender/utility.cuh index b82f629a0ac..b3277e7dff4 100644 --- a/cudax/include/cuda/experimental/__async/sender/utility.cuh +++ b/cudax/include/cuda/experimental/__async/sender/utility.cuh @@ -21,6 +21,8 @@ # pragma system_header #endif // no system header +#include +#include #include #include @@ -34,11 +36,7 @@ namespace cuda::experimental::__async { _CCCL_GLOBAL_CONSTANT size_t __npos = static_cast(-1); -struct __ignore -{ - template - _CUDAX_API constexpr __ignore(_As&&...) noexcept {}; -}; +using __ignore _CCCL_NODEBUG_ALIAS = _CUDA_VSTD::__ignore_t; // NOLINT: misc-unused-using-decls using _CUDA_VSTD::__undefined; // NOLINT: misc-unused-using-decls @@ -116,7 +114,7 @@ _CUDAX_API constexpr void __swap(_Ty& __left, _Ty& __right) noexcept } template -_CUDAX_API constexpr _Ty __decay_copy(_Ty&& __ty) noexcept(__nothrow_decay_copyable<_Ty>) +_CUDAX_API constexpr _CUDA_VSTD::decay_t<_Ty> __decay_copy(_Ty&& __ty) noexcept(__nothrow_decay_copyable<_Ty>) { return static_cast<_Ty&&>(__ty); } diff --git a/cudax/include/cuda/experimental/__async/sender/write_env.cuh b/cudax/include/cuda/experimental/__async/sender/write_env.cuh index 1a9d6b913a8..00d28a27e78 100644 --- a/cudax/include/cuda/experimental/__async/sender/write_env.cuh +++ b/cudax/include/cuda/experimental/__async/sender/write_env.cuh @@ -50,7 +50,7 @@ private: connect_result_t<_Sndr, __rcvr_with_env_t<_Rcvr, _Env>*> __opstate_; _CUDAX_API explicit __opstate_t(_Sndr&& __sndr, _Env __env, _Rcvr __rcvr) - : __env_rcvr_(static_cast<_Env&&>(__env), static_cast<_Rcvr&&>(__rcvr)) + : __env_rcvr_{static_cast<_Rcvr&&>(__rcvr), static_cast<_Env&&>(__env)} , __opstate_(__async::connect(static_cast<_Sndr&&>(__sndr), &__env_rcvr_)) {}