Skip to content

Commit 25901d7

Browse files
PTX: mbarrier.{test,try}_wait: Fix return value (#3670) (#3672)
* mbarrier.{test,try}_wait: Fix return value

(cherry picked from commit f61670e)

Co-authored-by: Allard Hendriksen <ahendriksen@nvidia.com>
1 parent 7d3a4d7 commit 25901d7

File tree

12 files changed (+122 -156 lines changed)

docs/libcudacxx/ptx/instructions/generated/mbarrier_test_wait.rst

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,10 +47,9 @@ mbarrier.test_wait.relaxed.cta.shared::cta.b64
4747
// .sem = { .relaxed }
4848
// .scope = { .cta, .cluster }
4949
template <cuda::ptx::dot_scope Scope>
50-
__device__ static inline void mbarrier_test_wait(
50+
__device__ static inline bool mbarrier_test_wait(
5151
cuda::ptx::sem_relaxed_t,
5252
cuda::ptx::scope_t<Scope> scope,
53-
bool waitComplete,
5453
uint64_t* addr,
5554
const uint64_t& state);
5655
@@ -62,9 +61,8 @@ mbarrier.test_wait.relaxed.cluster.shared::cta.b64
6261
// .sem = { .relaxed }
6362
// .scope = { .cta, .cluster }
6463
template <cuda::ptx::dot_scope Scope>
65-
__device__ static inline void mbarrier_test_wait(
64+
__device__ static inline bool mbarrier_test_wait(
6665
cuda::ptx::sem_relaxed_t,
6766
cuda::ptx::scope_t<Scope> scope,
68-
bool waitComplete,
6967
uint64_t* addr,
7068
const uint64_t& state);

docs/libcudacxx/ptx/instructions/generated/mbarrier_test_wait_parity.rst

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,10 +47,9 @@ mbarrier.test_wait.parity.relaxed.cta.shared::cta.b64
4747
// .sem = { .relaxed }
4848
// .scope = { .cta, .cluster }
4949
template <cuda::ptx::dot_scope Scope>
50-
__device__ static inline void mbarrier_test_wait_parity(
50+
__device__ static inline bool mbarrier_test_wait_parity(
5151
cuda::ptx::sem_relaxed_t,
5252
cuda::ptx::scope_t<Scope> scope,
53-
bool waitComplete,
5453
uint64_t* addr,
5554
const uint32_t& phaseParity);
5655
@@ -62,9 +61,8 @@ mbarrier.test_wait.parity.relaxed.cluster.shared::cta.b64
6261
// .sem = { .relaxed }
6362
// .scope = { .cta, .cluster }
6463
template <cuda::ptx::dot_scope Scope>
65-
__device__ static inline void mbarrier_test_wait_parity(
64+
__device__ static inline bool mbarrier_test_wait_parity(
6665
cuda::ptx::sem_relaxed_t,
6766
cuda::ptx::scope_t<Scope> scope,
68-
bool waitComplete,
6967
uint64_t* addr,
7068
const uint32_t& phaseParity);

docs/libcudacxx/ptx/instructions/generated/mbarrier_try_wait.rst

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -88,10 +88,9 @@ mbarrier.try_wait.relaxed.cta.shared::cta.b64
8888
// .sem = { .relaxed }
8989
// .scope = { .cta, .cluster }
9090
template <cuda::ptx::dot_scope Scope>
91-
__device__ static inline void mbarrier_try_wait(
91+
__device__ static inline bool mbarrier_try_wait(
9292
cuda::ptx::sem_relaxed_t,
9393
cuda::ptx::scope_t<Scope> scope,
94-
bool waitComplete,
9594
uint64_t* addr,
9695
const uint64_t& state,
9796
const uint32_t& suspendTimeHint);
@@ -104,10 +103,9 @@ mbarrier.try_wait.relaxed.cluster.shared::cta.b64
104103
// .sem = { .relaxed }
105104
// .scope = { .cta, .cluster }
106105
template <cuda::ptx::dot_scope Scope>
107-
__device__ static inline void mbarrier_try_wait(
106+
__device__ static inline bool mbarrier_try_wait(
108107
cuda::ptx::sem_relaxed_t,
109108
cuda::ptx::scope_t<Scope> scope,
110-
bool waitComplete,
111109
uint64_t* addr,
112110
const uint64_t& state,
113111
const uint32_t& suspendTimeHint);
@@ -120,10 +118,9 @@ mbarrier.try_wait.relaxed.cta.shared::cta.b64
120118
// .sem = { .relaxed }
121119
// .scope = { .cta, .cluster }
122120
template <cuda::ptx::dot_scope Scope>
123-
__device__ static inline void mbarrier_try_wait(
121+
__device__ static inline bool mbarrier_try_wait(
124122
cuda::ptx::sem_relaxed_t,
125123
cuda::ptx::scope_t<Scope> scope,
126-
bool waitComplete,
127124
uint64_t* addr,
128125
const uint64_t& state);
129126
@@ -135,9 +132,8 @@ mbarrier.try_wait.relaxed.cluster.shared::cta.b64
135132
// .sem = { .relaxed }
136133
// .scope = { .cta, .cluster }
137134
template <cuda::ptx::dot_scope Scope>
138-
__device__ static inline void mbarrier_try_wait(
135+
__device__ static inline bool mbarrier_try_wait(
139136
cuda::ptx::sem_relaxed_t,
140137
cuda::ptx::scope_t<Scope> scope,
141-
bool waitComplete,
142138
uint64_t* addr,
143139
const uint64_t& state);

docs/libcudacxx/ptx/instructions/generated/mbarrier_try_wait_parity.rst

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -88,10 +88,9 @@ mbarrier.try_wait.parity.relaxed.cta.shared::cta.b64
8888
// .sem = { .relaxed }
8989
// .scope = { .cta, .cluster }
9090
template <cuda::ptx::dot_scope Scope>
91-
__device__ static inline void mbarrier_try_wait_parity(
91+
__device__ static inline bool mbarrier_try_wait_parity(
9292
cuda::ptx::sem_relaxed_t,
9393
cuda::ptx::scope_t<Scope> scope,
94-
bool waitComplete,
9594
uint64_t* addr,
9695
const uint32_t& phaseParity,
9796
const uint32_t& suspendTimeHint);
@@ -104,10 +103,9 @@ mbarrier.try_wait.parity.relaxed.cluster.shared::cta.b64
104103
// .sem = { .relaxed }
105104
// .scope = { .cta, .cluster }
106105
template <cuda::ptx::dot_scope Scope>
107-
__device__ static inline void mbarrier_try_wait_parity(
106+
__device__ static inline bool mbarrier_try_wait_parity(
108107
cuda::ptx::sem_relaxed_t,
109108
cuda::ptx::scope_t<Scope> scope,
110-
bool waitComplete,
111109
uint64_t* addr,
112110
const uint32_t& phaseParity,
113111
const uint32_t& suspendTimeHint);
@@ -120,10 +118,9 @@ mbarrier.try_wait.parity.relaxed.cta.shared::cta.b64
120118
// .sem = { .relaxed }
121119
// .scope = { .cta, .cluster }
122120
template <cuda::ptx::dot_scope Scope>
123-
__device__ static inline void mbarrier_try_wait_parity(
121+
__device__ static inline bool mbarrier_try_wait_parity(
124122
cuda::ptx::sem_relaxed_t,
125123
cuda::ptx::scope_t<Scope> scope,
126-
bool waitComplete,
127124
uint64_t* addr,
128125
const uint32_t& phaseParity);
129126
@@ -135,9 +132,8 @@ mbarrier.try_wait.parity.relaxed.cluster.shared::cta.b64
135132
// .sem = { .relaxed }
136133
// .scope = { .cta, .cluster }
137134
template <cuda::ptx::dot_scope Scope>
138-
__device__ static inline void mbarrier_try_wait_parity(
135+
__device__ static inline bool mbarrier_try_wait_parity(
139136
cuda::ptx::sem_relaxed_t,
140137
cuda::ptx::scope_t<Scope> scope,
141-
bool waitComplete,
142138
uint64_t* addr,
143139
const uint32_t& phaseParity);

libcudacxx/include/cuda/__ptx/instructions/generated/mbarrier_test_wait.h

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -89,49 +89,47 @@ _CCCL_DEVICE static inline bool mbarrier_test_wait(
8989
// .sem = { .relaxed }
9090
// .scope = { .cta, .cluster }
9191
template <cuda::ptx::dot_scope Scope>
92-
__device__ static inline void mbarrier_test_wait(
92+
__device__ static inline bool mbarrier_test_wait(
9393
cuda::ptx::sem_relaxed_t,
9494
cuda::ptx::scope_t<Scope> scope,
95-
bool waitComplete,
9695
uint64_t* addr,
9796
const uint64_t& state);
9897
*/
9998
#if __cccl_ptx_isa >= 860
10099
extern "C" _CCCL_DEVICE void __cuda_ptx_mbarrier_test_wait_is_not_supported_before_SM_90__();
101100
template <dot_scope _Scope>
102-
_CCCL_DEVICE static inline void mbarrier_test_wait(
103-
sem_relaxed_t,
104-
scope_t<_Scope> __scope,
105-
bool __waitComplete,
106-
_CUDA_VSTD::uint64_t* __addr,
107-
const _CUDA_VSTD::uint64_t& __state)
101+
_CCCL_DEVICE static inline bool mbarrier_test_wait(
102+
sem_relaxed_t, scope_t<_Scope> __scope, _CUDA_VSTD::uint64_t* __addr, const _CUDA_VSTD::uint64_t& __state)
108103
{
109104
// __sem == sem_relaxed (due to parameter type constraint)
110105
static_assert(__scope == scope_cta || __scope == scope_cluster, "");
111106
# if _CCCL_CUDA_COMPILER(NVHPC) || __CUDA_ARCH__ >= 900
107+
_CUDA_VSTD::uint32_t __waitComplete;
112108
_CCCL_IF_CONSTEXPR (__scope == scope_cta)
113109
{
114-
asm("{\n\t .reg .pred PRED_waitComplete; \n\t"
115-
"setp.ne.b32 PRED_waitComplete, %0, 0;\n\t"
116-
"mbarrier.test_wait.relaxed.cta.shared::cta.b64 PRED_waitComplete, [%1], %2;\n\t"
110+
asm("{\n\t .reg .pred P_OUT; \n\t"
111+
"mbarrier.test_wait.relaxed.cta.shared::cta.b64 P_OUT, [%1], %2;\n\t"
112+
"selp.b32 %0, 1, 0, P_OUT; \n"
117113
"}"
118-
:
119-
: "r"(static_cast<_CUDA_VSTD::uint32_t>(__waitComplete)), "r"(__as_ptr_smem(__addr)), "l"(__state)
114+
: "=r"(__waitComplete)
115+
: "r"(__as_ptr_smem(__addr)), "l"(__state)
120116
: "memory");
121117
}
122118
else _CCCL_IF_CONSTEXPR (__scope == scope_cluster)
123119
{
124-
asm("{\n\t .reg .pred PRED_waitComplete; \n\t"
125-
"setp.ne.b32 PRED_waitComplete, %0, 0;\n\t"
126-
"mbarrier.test_wait.relaxed.cluster.shared::cta.b64 PRED_waitComplete, [%1], %2;\n\t"
120+
asm("{\n\t .reg .pred P_OUT; \n\t"
121+
"mbarrier.test_wait.relaxed.cluster.shared::cta.b64 P_OUT, [%1], %2;\n\t"
122+
"selp.b32 %0, 1, 0, P_OUT; \n"
127123
"}"
128-
:
129-
: "r"(static_cast<_CUDA_VSTD::uint32_t>(__waitComplete)), "r"(__as_ptr_smem(__addr)), "l"(__state)
124+
: "=r"(__waitComplete)
125+
: "r"(__as_ptr_smem(__addr)), "l"(__state)
130126
: "memory");
131127
}
128+
return static_cast<bool>(__waitComplete);
132129
# else
133130
// Unsupported architectures will have a linker error with a semi-decent error message
134131
__cuda_ptx_mbarrier_test_wait_is_not_supported_before_SM_90__();
132+
return false;
135133
# endif
136134
}
137135
#endif // __cccl_ptx_isa >= 860

libcudacxx/include/cuda/__ptx/instructions/generated/mbarrier_test_wait_parity.h

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -90,49 +90,47 @@ _CCCL_DEVICE static inline bool mbarrier_test_wait_parity(
9090
// .sem = { .relaxed }
9191
// .scope = { .cta, .cluster }
9292
template <cuda::ptx::dot_scope Scope>
93-
__device__ static inline void mbarrier_test_wait_parity(
93+
__device__ static inline bool mbarrier_test_wait_parity(
9494
cuda::ptx::sem_relaxed_t,
9595
cuda::ptx::scope_t<Scope> scope,
96-
bool waitComplete,
9796
uint64_t* addr,
9897
const uint32_t& phaseParity);
9998
*/
10099
#if __cccl_ptx_isa >= 860
101100
extern "C" _CCCL_DEVICE void __cuda_ptx_mbarrier_test_wait_parity_is_not_supported_before_SM_90__();
102101
template <dot_scope _Scope>
103-
_CCCL_DEVICE static inline void mbarrier_test_wait_parity(
104-
sem_relaxed_t,
105-
scope_t<_Scope> __scope,
106-
bool __waitComplete,
107-
_CUDA_VSTD::uint64_t* __addr,
108-
const _CUDA_VSTD::uint32_t& __phaseParity)
102+
_CCCL_DEVICE static inline bool mbarrier_test_wait_parity(
103+
sem_relaxed_t, scope_t<_Scope> __scope, _CUDA_VSTD::uint64_t* __addr, const _CUDA_VSTD::uint32_t& __phaseParity)
109104
{
110105
// __sem == sem_relaxed (due to parameter type constraint)
111106
static_assert(__scope == scope_cta || __scope == scope_cluster, "");
112107
# if _CCCL_CUDA_COMPILER(NVHPC) || __CUDA_ARCH__ >= 900
108+
_CUDA_VSTD::uint32_t __waitComplete;
113109
_CCCL_IF_CONSTEXPR (__scope == scope_cta)
114110
{
115-
asm("{\n\t .reg .pred PRED_waitComplete; \n\t"
116-
"setp.ne.b32 PRED_waitComplete, %0, 0;\n\t"
117-
"mbarrier.test_wait.parity.relaxed.cta.shared::cta.b64 PRED_waitComplete, [%1], %2;\n\t"
111+
asm("{\n\t .reg .pred P_OUT; \n\t"
112+
"mbarrier.test_wait.parity.relaxed.cta.shared::cta.b64 P_OUT, [%1], %2;\n\t"
113+
"selp.b32 %0, 1, 0, P_OUT; \n"
118114
"}"
119-
:
120-
: "r"(static_cast<_CUDA_VSTD::uint32_t>(__waitComplete)), "r"(__as_ptr_smem(__addr)), "r"(__phaseParity)
115+
: "=r"(__waitComplete)
116+
: "r"(__as_ptr_smem(__addr)), "r"(__phaseParity)
121117
: "memory");
122118
}
123119
else _CCCL_IF_CONSTEXPR (__scope == scope_cluster)
124120
{
125-
asm("{\n\t .reg .pred PRED_waitComplete; \n\t"
126-
"setp.ne.b32 PRED_waitComplete, %0, 0;\n\t"
127-
"mbarrier.test_wait.parity.relaxed.cluster.shared::cta.b64 PRED_waitComplete, [%1], %2;\n\t"
121+
asm("{\n\t .reg .pred P_OUT; \n\t"
122+
"mbarrier.test_wait.parity.relaxed.cluster.shared::cta.b64 P_OUT, [%1], %2;\n\t"
123+
"selp.b32 %0, 1, 0, P_OUT; \n"
128124
"}"
129-
:
130-
: "r"(static_cast<_CUDA_VSTD::uint32_t>(__waitComplete)), "r"(__as_ptr_smem(__addr)), "r"(__phaseParity)
125+
: "=r"(__waitComplete)
126+
: "r"(__as_ptr_smem(__addr)), "r"(__phaseParity)
131127
: "memory");
132128
}
129+
return static_cast<bool>(__waitComplete);
133130
# else
134131
// Unsupported architectures will have a linker error with a semi-decent error message
135132
__cuda_ptx_mbarrier_test_wait_parity_is_not_supported_before_SM_90__();
133+
return false;
136134
# endif
137135
}
138136
#endif // __cccl_ptx_isa >= 860

0 commit comments

Comments
 (0)