Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cudax/include/cuda/experimental/__async/sender/env.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT prop
prop& operator=(const prop&) = delete;
};

// Deduction guide for class template argument deduction on `prop`:
// a (query, value) pair of arguments deduces `prop<_Query, _Value>`.
// Both guide parameters are by value, so the deduced types are the decayed
// argument types (no references or top-level cv-qualifiers).
template <class _Query, class _Value>
prop(_Query, _Value) -> prop<_Query, _Value>;

template <class... _Envs>
struct _CCCL_TYPE_VISIBILITY_DEFAULT env
{
Expand Down
2 changes: 2 additions & 0 deletions cudax/include/cuda/experimental/__async/sender/meta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct _IN_ALGORITHM;

struct _WHAT;

struct _WHY;

struct _WITH_FUNCTION;

struct _WITH_SENDER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ struct __rcvr_with_env_t<_Rcvr*, _Env>
_Rcvr* __rcvr_;
_Env __env_;
};

// Deduction guide for class template argument deduction on
// `__rcvr_with_env_t`: a (receiver, environment) pair of arguments deduces
// `__rcvr_with_env_t<_Rcvr, _Env>`. Both guide parameters are by value, so
// the deduced types are the decayed argument types.
template <class _Rcvr, class _Env>
__rcvr_with_env_t(_Rcvr, _Env) -> __rcvr_with_env_t<_Rcvr, _Env>;

} // namespace cuda::experimental::__async

#include <cuda/experimental/__async/sender/epilogue.cuh>
Expand Down
47 changes: 33 additions & 14 deletions cudax/include/cuda/experimental/__async/sender/sync_wait.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
// run_loop isn't supported on-device yet, so neither can sync_wait be.
#if !defined(__CUDA_ARCH__)

# include <cuda/std/__type_traits/type_identity.h>
# include <cuda/std/optional>
# include <cuda/std/tuple>

# include <cuda/experimental/__async/sender/exception.cuh>
# include <cuda/experimental/__async/sender/meta.cuh>
# include <cuda/experimental/__async/sender/run_loop.cuh>
# include <cuda/experimental/__async/sender/utility.cuh>
# include <cuda/experimental/__async/sender/write_env.cuh>

# include <system_error>

Expand Down Expand Up @@ -115,26 +117,38 @@ private:
}
};

using __values_t = value_types_of_t<_Sndr, __rcvr_t, _CUDA_VSTD::tuple, _CUDA_VSTD::__type_self_t>;
using __completions_t = completion_signatures_of_t<_Sndr, __rcvr_t>;

// Holder for the result-tuple type computed from `__completions_t`; the
// `__value_types<...>` computation is only reachable via `__on_success::type`.
// NOTE(review): presumably the nested struct defers that computation until
// this branch is actually selected (i.e. `__completions_t` really is a
// completion-signatures type) -- confirm against the definition of `_If`.
struct __on_success
{
using type = __value_types<__completions_t, _CUDA_VSTD::tuple, _CUDA_VSTD::__type_self_t>;
};

using __on_error = _CUDA_VSTD::type_identity<_CUDA_VSTD::tuple<__completions_t>>;

using __values_t =
typename _CUDA_VSTD::_If<__is_completion_signatures<__completions_t>, __on_success, __on_error>::type;

_CUDA_VSTD::optional<__values_t>* __values_;
::std::exception_ptr __eptr_;
run_loop __loop_;
};

struct __invalid_sync_wait
// Type-dependent `false`: derives from `false_type` for every `_Type`.
// Because the value depends on a template parameter, a static_assert written
// in terms of `__always_false<T>()` is only evaluated when the enclosing
// template is instantiated (see its use in `__bad_sync_wait`).
template <class _Type>
struct __always_false : _CUDA_VSTD::false_type
{};

template <class _Diagnostic>
struct __bad_sync_wait
{
const __invalid_sync_wait& value() const
{
return *this;
}
static_assert(__always_false<_Diagnostic>(),
"sync_wait cannot compute the completions of the sender passed to it.");
static __bad_sync_wait __result();

const __invalid_sync_wait& operator*() const
{
return *this;
}
const __bad_sync_wait& value() const;
const __bad_sync_wait& operator*() const;

int __i_;
int i{}; // so that structured bindings kinda work
};

public:
Expand Down Expand Up @@ -168,12 +182,11 @@ public:
{
using __rcvr_t = typename __state_t<_Sndr>::__rcvr_t;
using __values_t = typename __state_t<_Sndr>::__values_t;
using __completions = completion_signatures_of_t<_Sndr, __rcvr_t>;
static_assert(__is_completion_signatures<__completions>);
using __completions = typename __state_t<_Sndr>::__completions_t;

if constexpr (!__is_completion_signatures<__completions>)
{
return __invalid_sync_wait{0};
return __bad_sync_wait<__completions>::__result();
}
else
{
Expand All @@ -196,6 +209,12 @@ public:
return __result; // uses NRVO to "return" the result
}
}

// Call-operator overload that takes an additional environment.
// Wraps `__sndr` together with `__env` using `__async::write_env` and passes
// the resulting sender to the single-argument call operator of this object.
// Both arguments are forwarded with their original value category; the
// return type is whatever the single-argument overload returns.
template <class _Sndr, class _Env>
auto operator()(_Sndr&& __sndr, _Env&& __env) const
{
return (*this)(__async::write_env(static_cast<_Sndr&&>(__sndr), static_cast<_Env&&>(__env)));
}
};

_CCCL_GLOBAL_CONSTANT sync_wait_t sync_wait{};
Expand Down
10 changes: 4 additions & 6 deletions cudax/include/cuda/experimental/__async/sender/utility.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
# pragma system_header
#endif // no system header

#include <cuda/std/__tuple_dir/ignore.h>
#include <cuda/std/__type_traits/decay.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/initializer_list>

Expand All @@ -34,11 +36,7 @@ namespace cuda::experimental::__async
{
_CCCL_GLOBAL_CONSTANT size_t __npos = static_cast<size_t>(-1);

struct __ignore
{
template <class... _As>
_CUDAX_API constexpr __ignore(_As&&...) noexcept {};
};
using __ignore _CCCL_NODEBUG_ALIAS = _CUDA_VSTD::__ignore_t; // NOLINT: misc-unused-using-decls

using _CUDA_VSTD::__undefined; // NOLINT: misc-unused-using-decls

Expand Down Expand Up @@ -116,7 +114,7 @@ _CUDAX_API constexpr void __swap(_Ty& __left, _Ty& __right) noexcept
}

template <class _Ty>
_CUDAX_API constexpr _Ty __decay_copy(_Ty&& __ty) noexcept(__nothrow_decay_copyable<_Ty>)
_CUDAX_API constexpr _CUDA_VSTD::decay_t<_Ty> __decay_copy(_Ty&& __ty) noexcept(__nothrow_decay_copyable<_Ty>)
{
return static_cast<_Ty&&>(__ty);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ private:
connect_result_t<_Sndr, __rcvr_with_env_t<_Rcvr, _Env>*> __opstate_;

_CUDAX_API explicit __opstate_t(_Sndr&& __sndr, _Env __env, _Rcvr __rcvr)
: __env_rcvr_(static_cast<_Env&&>(__env), static_cast<_Rcvr&&>(__rcvr))
: __env_rcvr_{static_cast<_Rcvr&&>(__rcvr), static_cast<_Env&&>(__env)}
, __opstate_(__async::connect(static_cast<_Sndr&&>(__sndr), &__env_rcvr_))
{}

Expand Down
Loading