forked from NVIDIA/cccl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstateful_operator.cu
More file actions
70 lines (56 loc) · 1.66 KB
/
stateful_operator.cu
File metadata and controls
70 lines (56 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <cuda/__cccl_config>
_CCCL_SUPPRESS_DEPRECATED_PUSH
#include <thrust/detail/config.h>
#if _CCCL_STD_VER >= 2014
# include <async/inclusive_scan/mixin.h>
# include <async/test_policy_overloads.h>
namespace
{
// Custom binary operator for scan: a stateful functor that adds the stored
// `offset` to the sum of its two operands.
//
// The operator is associative, as required for a parallel scan:
//   (a op b) op c == a op (b op c) == a + b + c + 2 * offset
// so the device result is well-defined and comparable to a host reference.
template <typename T>
struct stateful_operator
{
  // State carried by each copy of the functor; added once per application.
  T offset;

  // Returns v1 + v2 + offset.
  // Const-qualified because the call never modifies `offset`; this lets the
  // functor be invoked through const objects/references (as generic algorithm
  // implementations often do) while remaining callable on non-const objects.
  __host__ __device__ T operator()(T v1, T v2) const
  {
    return v1 + v2 + offset;
  }
};
// Postfix-args provider for the async inclusive_scan test harness.
// It supplies a stateful custom binary operator as the postfix argument of the
// overload under test. Presumably the harness appends these after the
// iterator/policy arguments; confirm against the mixin definitions.
template <typename value_type>
struct use_stateful_operator
{
  // Outer tuple: one entry per overload (exactly one here).
  // Inner tuple: that overload's postfix arguments (just the bin_op).
  using postfix_args_type = std::tuple<std::tuple<stateful_operator<value_type>>>;

  // Builds the postfix arguments: one overload whose binary operator carries
  // an offset of 2.
  static postfix_args_type generate_postfix_args()
  {
    using bin_op_type = stateful_operator<value_type>;
    const bin_op_type bin_op{value_type{2}};
    return postfix_args_type{std::tuple<bin_op_type>{bin_op}};
  }
};
// Test invoker assembled from harness mixins. Judging by their names (the
// definitions are not visible here; confirm), the bases provide:
//   - device_vector input and output storage for `value_type`,
//   - the postfix arguments (a stateful_operator with offset 2) via
//     use_stateful_operator,
//   - a host-side synchronous reference inclusive_scan,
//   - the async inclusive_scan invocation under test ("simple"),
//   - an output comparison that is approximate for floating-point types.
// NOTE(review): base-class order is kept as-is; it determines base
// construction order. Do not reorder without checking the mixins.
template <typename value_type>
struct invoker
    : testing::async::mixin::input::device_vector<value_type>
    , testing::async::mixin::output::device_vector<value_type>
    , use_stateful_operator<value_type>
    , testing::async::inclusive_scan::mixin::invoke_reference::host_synchronous<value_type>
    , testing::async::inclusive_scan::mixin::invoke_async::simple
    , testing::async::mixin::compare_outputs::assert_almost_equal_if_fp_quiet
{
  // Human-readable label for this test case; presumably printed by the
  // harness when reporting results.
  static std::string description()
  {
    return "scan with stateful operator";
  }
};
} // namespace
// Sized test entry point. For element type T, it runs the stateful-operator
// scan invoker through test_policy_overloads with `num_values` elements.
// Presumably this covers each execution-policy overload the harness defines.
template <typename T>
struct test_stateful_operator
{
  void operator()(std::size_t num_values) const
  {
    using invoker_type   = invoker<T>;
    using overloads_type = testing::async::test_policy_overloads<invoker_type>;
    overloads_type::run(num_values);
  }
};
// Registers the test for every type in NumericTypes. Presumably the unittest
// framework also drives it across a range of input sizes (see the macro
// definition).
DECLARE_GENERIC_SIZED_UNITTEST_WITH_TYPES(test_stateful_operator, NumericTypes);
#endif // C++14
_CCCL_SUPPRESS_DEPRECATED_POP