Skip to content

Commit 665b376

Browse files
authored
Implement math functions for thrust::complex (#1178) (#1191)
* Implement math functions for `thrust::complex` We are having issues that the `cuda::std` math functions that take a `cuda::std::complex` return a `cuda::std::complex`. This can lead to issues as we require a conversion sequence from `cuda::std::complex` to `thrust::complex` which e.g is broken by an constructor being explicit. Addresses nvbug4397241
1 parent 265d985 commit 665b376

File tree

2 files changed

+106
-1
lines changed

2 files changed

+106
-1
lines changed

thrust/testing/complex.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,3 +679,21 @@ struct TestComplexStdComplexDeviceInterop
679679
SimpleUnitTest<TestComplexStdComplexDeviceInterop, FloatingPointTypes>
680680
TestComplexStdComplexDeviceInteropInstance;
681681
#endif
682+
683+
template <typename T>
684+
struct TestComplexExplicitConstruction
685+
{
686+
struct user_complex {
687+
__host__ __device__ user_complex(T, T) {}
688+
__host__ __device__ user_complex(const thrust::complex<T>&) {}
689+
};
690+
691+
void operator()()
692+
{
693+
const thrust::complex<T> input(42.0, 1337.0);
694+
const user_complex result = thrust::exp(input);
695+
(void)result;
696+
}
697+
};
698+
SimpleUnitTest<TestComplexExplicitConstruction, FloatingPointTypes>
699+
TestComplexExplicitConstructionInstance;

thrust/thrust/complex.h

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,11 +461,14 @@ operator/(const T0 &x, const complex<T1> &y)
461461

462462
// The using declarations allows imports all necessary functions for thurst::complex.
463463
// However, they also lead to thrust::abs(1.0F) being valid code after include <thurst/complex.h>.
464+
// We are importing those for the plain value taking overloads and specialize for those taking
465+
// or returning a `thrust::complex` below
464466
using ::cuda::std::abs;
465467
using ::cuda::std::arg;
466468
using ::cuda::std::conj;
467469
using ::cuda::std::norm;
468-
using ::cuda::std::polar;
470+
// polar only takes a T but returns a complex<T> so we cannot pull that one in.
471+
// using ::cuda::std::polar;
469472
using ::cuda::std::proj;
470473

471474
using ::cuda::std::exp;
@@ -487,6 +490,90 @@ using ::cuda::std::sinh;
487490
using ::cuda::std::tan;
488491
using ::cuda::std::tanh;
489492

493+
// Those functions return `cuda::std::complex<T>` so we must provide an explicit overload that returns `thrust::complex<T>`
494+
template<class T>
495+
__host__ __device__ complex<T> conj(const complex<T>& c) {
496+
return static_cast<complex<T>>(::cuda::std::conj(c));
497+
}
498+
template<class T>
499+
__host__ __device__ complex<T> polar(const T& rho, const T& theta = T{}) {
500+
return static_cast<complex<T>>(::cuda::std::polar(rho, theta));
501+
}
502+
template<class T>
503+
__host__ __device__ complex<T> proj(const complex<T>& c) {
504+
return static_cast<complex<T>>(::cuda::std::proj(c));
505+
}
506+
507+
template<class T>
508+
__host__ __device__ complex<T> exp(const complex<T>& c) {
509+
return static_cast<complex<T>>(::cuda::std::exp(c));
510+
}
511+
template<class T>
512+
__host__ __device__ complex<T> log(const complex<T>& c) {
513+
return static_cast<complex<T>>(::cuda::std::log(c));
514+
}
515+
template<class T>
516+
__host__ __device__ complex<T> log10(const complex<T>& c) {
517+
return static_cast<complex<T>>(::cuda::std::log10(c));
518+
}
519+
template<class T>
520+
__host__ __device__ complex<T> pow(const complex<T>& c) {
521+
return static_cast<complex<T>>(::cuda::std::pow(c));
522+
}
523+
template<class T>
524+
__host__ __device__ complex<T> sqrt(const complex<T>& c) {
525+
return static_cast<complex<T>>(::cuda::std::sqrt(c));
526+
}
527+
528+
template<class T>
529+
__host__ __device__ complex<T> acos(const complex<T>& c) {
530+
return static_cast<complex<T>>(::cuda::std::acos(c));
531+
}
532+
template<class T>
533+
__host__ __device__ complex<T> acosh(const complex<T>& c) {
534+
return static_cast<complex<T>>(::cuda::std::acosh(c));
535+
}
536+
template<class T>
537+
__host__ __device__ complex<T> asin(const complex<T>& c) {
538+
return static_cast<complex<T>>(::cuda::std::asin(c));
539+
}
540+
template<class T>
541+
__host__ __device__ complex<T> asinh(const complex<T>& c) {
542+
return static_cast<complex<T>>(::cuda::std::asinh(c));
543+
}
544+
template<class T>
545+
__host__ __device__ complex<T> atan(const complex<T>& c) {
546+
return static_cast<complex<T>>(::cuda::std::atan(c));
547+
}
548+
template<class T>
549+
__host__ __device__ complex<T> atanh(const complex<T>& c) {
550+
return static_cast<complex<T>>(::cuda::std::atanh(c));
551+
}
552+
template<class T>
553+
__host__ __device__ complex<T> cos(const complex<T>& c) {
554+
return static_cast<complex<T>>(::cuda::std::cos(c));
555+
}
556+
template<class T>
557+
__host__ __device__ complex<T> cosh(const complex<T>& c) {
558+
return static_cast<complex<T>>(::cuda::std::cosh(c));
559+
}
560+
template<class T>
561+
__host__ __device__ complex<T> sin(const complex<T>& c) {
562+
return static_cast<complex<T>>(::cuda::std::sin(c));
563+
}
564+
template<class T>
565+
__host__ __device__ complex<T> sinh(const complex<T>& c) {
566+
return static_cast<complex<T>>(::cuda::std::sinh(c));
567+
}
568+
template<class T>
569+
__host__ __device__ complex<T> tan(const complex<T>& c) {
570+
return static_cast<complex<T>>(::cuda::std::tan(c));
571+
}
572+
template<class T>
573+
__host__ __device__ complex<T> tanh(const complex<T>& c) {
574+
return static_cast<complex<T>>(::cuda::std::tanh(c));
575+
}
576+
490577
template <typename T>
491578
struct proclaim_trivially_relocatable<complex<T>> : thrust::true_type
492579
{};

0 commit comments

Comments
 (0)