@@ -461,11 +461,14 @@ operator/(const T0 &x, const complex<T1> &y)
461461
462462// The using declarations allows imports all necessary functions for thurst::complex.
463463// However, they also lead to thrust::abs(1.0F) being valid code after include <thurst/complex.h>.
464+ // We are importing those for the plain value taking overloads and specialize for those taking
465+ // or returning a `thrust::complex` below
464466using ::cuda::std::abs;
465467using ::cuda::std::arg;
466468using ::cuda::std::conj;
467469using ::cuda::std::norm;
468- using ::cuda::std::polar;
470+ // polar only takes a T but returns a complex<T> so we cannot pull that one in.
471+ // using ::cuda::std::polar;
469472using ::cuda::std::proj;
470473
471474using ::cuda::std::exp;
@@ -487,6 +490,90 @@ using ::cuda::std::sinh;
487490using ::cuda::std::tan;
488491using ::cuda::std::tanh;
489492
493+ // Those functions return `cuda::std::complex<T>` so we must provide an explicit overload that returns `thrust::complex<T>`
494+ template <class T >
495+ __host__ __device__ complex <T> conj (const complex <T>& c) {
496+ return static_cast <complex <T>>(::cuda::std::conj (c));
497+ }
498+ template <class T >
499+ __host__ __device__ complex <T> polar (const T& rho, const T& theta = T{}) {
500+ return static_cast <complex <T>>(::cuda::std::polar (rho, theta));
501+ }
502+ template <class T >
503+ __host__ __device__ complex <T> proj (const complex <T>& c) {
504+ return static_cast <complex <T>>(::cuda::std::proj (c));
505+ }
506+
507+ template <class T >
508+ __host__ __device__ complex <T> exp (const complex <T>& c) {
509+ return static_cast <complex <T>>(::cuda::std::exp (c));
510+ }
511+ template <class T >
512+ __host__ __device__ complex <T> log (const complex <T>& c) {
513+ return static_cast <complex <T>>(::cuda::std::log (c));
514+ }
515+ template <class T >
516+ __host__ __device__ complex <T> log10 (const complex <T>& c) {
517+ return static_cast <complex <T>>(::cuda::std::log10 (c));
518+ }
519+ template <class T >
520+ __host__ __device__ complex <T> pow (const complex <T>& c) {
521+ return static_cast <complex <T>>(::cuda::std::pow (c));
522+ }
523+ template <class T >
524+ __host__ __device__ complex <T> sqrt (const complex <T>& c) {
525+ return static_cast <complex <T>>(::cuda::std::sqrt (c));
526+ }
527+
528+ template <class T >
529+ __host__ __device__ complex <T> acos (const complex <T>& c) {
530+ return static_cast <complex <T>>(::cuda::std::acos (c));
531+ }
532+ template <class T >
533+ __host__ __device__ complex <T> acosh (const complex <T>& c) {
534+ return static_cast <complex <T>>(::cuda::std::acosh (c));
535+ }
536+ template <class T >
537+ __host__ __device__ complex <T> asin (const complex <T>& c) {
538+ return static_cast <complex <T>>(::cuda::std::asin (c));
539+ }
540+ template <class T >
541+ __host__ __device__ complex <T> asinh (const complex <T>& c) {
542+ return static_cast <complex <T>>(::cuda::std::asinh (c));
543+ }
544+ template <class T >
545+ __host__ __device__ complex <T> atan (const complex <T>& c) {
546+ return static_cast <complex <T>>(::cuda::std::atan (c));
547+ }
548+ template <class T >
549+ __host__ __device__ complex <T> atanh (const complex <T>& c) {
550+ return static_cast <complex <T>>(::cuda::std::atanh (c));
551+ }
552+ template <class T >
553+ __host__ __device__ complex <T> cos (const complex <T>& c) {
554+ return static_cast <complex <T>>(::cuda::std::cos (c));
555+ }
556+ template <class T >
557+ __host__ __device__ complex <T> cosh (const complex <T>& c) {
558+ return static_cast <complex <T>>(::cuda::std::cosh (c));
559+ }
560+ template <class T >
561+ __host__ __device__ complex <T> sin (const complex <T>& c) {
562+ return static_cast <complex <T>>(::cuda::std::sin (c));
563+ }
564+ template <class T >
565+ __host__ __device__ complex <T> sinh (const complex <T>& c) {
566+ return static_cast <complex <T>>(::cuda::std::sinh (c));
567+ }
568+ template <class T >
569+ __host__ __device__ complex <T> tan (const complex <T>& c) {
570+ return static_cast <complex <T>>(::cuda::std::tan (c));
571+ }
572+ template <class T >
573+ __host__ __device__ complex <T> tanh (const complex <T>& c) {
574+ return static_cast <complex <T>>(::cuda::std::tanh (c));
575+ }
576+
490577template <typename T>
491578struct proclaim_trivially_relocatable <complex <T>> : thrust::true_type
492579{};
0 commit comments