From 27f2ef05f63a9c606d20ce4231c9559f7b99c9d4 Mon Sep 17 00:00:00 2001 From: LindseyMei <648816901@qq.com> Date: Wed, 1 Jul 2026 07:40:45 +0000 Subject: [PATCH] feat: support lp_norm operator on metax Add MetaX backend for lp_norm: - Create metax/lp_norm_metax.h and metax/lp_norm_metax.maca - Wire MetaX dispatch into operator.cc (CREATE/GET/CALCULATE/DELETE) - Make lp_norm/cuda/kernel.cuh compile on MACA by excluding MetaX from the CUDART_VERSION >= 12090 ::cuda::maximum() branch Validation: python3 test/infiniop/lp_norm.py --metax passes on MetaX C500 (F16/BF16/F32, contiguous and strided axis=-1 cases). Signed-off-by: LindseyMei <648816901@qq.com> --- src/infiniop/ops/lp_norm/cuda/kernel.cuh | 14 +- .../ops/lp_norm/metax/lp_norm_metax.h | 8 + .../ops/lp_norm/metax/lp_norm_metax.maca | 167 ++++++++++++++++++ src/infiniop/ops/lp_norm/operator.cc | 15 ++ 4 files changed, 196 insertions(+), 8 deletions(-) create mode 100644 src/infiniop/ops/lp_norm/metax/lp_norm_metax.h create mode 100644 src/infiniop/ops/lp_norm/metax/lp_norm_metax.maca diff --git a/src/infiniop/ops/lp_norm/cuda/kernel.cuh b/src/infiniop/ops/lp_norm/cuda/kernel.cuh index 2c0c3c151..c3504187a 100644 --- a/src/infiniop/ops/lp_norm/cuda/kernel.cuh +++ b/src/infiniop/ops/lp_norm/cuda/kernel.cuh @@ -17,11 +17,10 @@ __device__ void blockLPNormKernel( local_max = max(local_max, fabsf((float)input[tid + ind * stride])); } __shared__ float global_max; -#if CUDART_VERSION >= 12090 +#if CUDART_VERSION >= 12090 && !defined(ENABLE_METAX_API) float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum()); -#elif defined(ENABLE_HYGON_API) - float max_block = BlockReduce(temp_storage).Reduce( - local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE); +#elif defined(ENABLE_HYGON_API) || defined(ENABLE_METAX_API) + float max_block = BlockReduce(temp_storage).Reduce(local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE); #else float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max()); #endif @@ -76,11 +75,10 @@ __device__ void blockLPNormStridesKernel( local_max = max(local_max, fabsf((float)input[ind_i + ind])); } __shared__ float global_max; -#if CUDART_VERSION >= 12090 +#if CUDART_VERSION >= 12090 && !defined(ENABLE_METAX_API) float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum()); -#elif defined(ENABLE_HYGON_API) - float max_block = BlockReduce(temp_storage).Reduce( - local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE); +#elif defined(ENABLE_HYGON_API) || defined(ENABLE_METAX_API) + float max_block = BlockReduce(temp_storage).Reduce(local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE); #else float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max()); #endif diff --git a/src/infiniop/ops/lp_norm/metax/lp_norm_metax.h b/src/infiniop/ops/lp_norm/metax/lp_norm_metax.h new file mode 100644 index 000000000..0812fd091 --- /dev/null +++ b/src/infiniop/ops/lp_norm/metax/lp_norm_metax.h @@ -0,0 +1,8 @@ +#ifndef __LP_NORM_METAX_H__ +#define __LP_NORM_METAX_H__ + +#include "../lp_norm.h" + +DESCRIPTOR(metax) + +#endif // __LP_NORM_METAX_H__ diff --git a/src/infiniop/ops/lp_norm/metax/lp_norm_metax.maca b/src/infiniop/ops/lp_norm/metax/lp_norm_metax.maca new file mode 100644 index 000000000..56778fdb4 --- /dev/null +++ b/src/infiniop/ops/lp_norm/metax/lp_norm_metax.maca @@ -0,0 +1,167 @@ +#include "../../../devices/metax/metax_common.h" +#include "../../../devices/metax/metax_kernel_common.h" +#include "../cuda/kernel.cuh" +#include "lp_norm_metax.h" + +template +INFINIOP_METAX_KERNEL blockLPNorm( + Tdata *y, const Tdata *x, + float p, size_t dimsize, + ptrdiff_t stride, float eps) { + blockLPNormKernel(x, y, p, dimsize, stride, eps); +} + +template +INFINIOP_METAX_KERNEL blockLPNormStrides( + Tdata *y, const Tdata *x, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_strides, + const size_t *shape, int ndim, + float p, size_t dimsize, float eps) { + blockLPNormStridesKernel( + x, y, output_strides, input_strides, shape, ndim, p, dimsize, eps); +} + +template +INFINIOP_METAX_KERNEL warpLPNorm( + Tdata *y, const Tdata *x, + float p, size_t othersize, size_t dimsize, + ptrdiff_t stride, float eps) { + warpLPNormKernel(x, y, p, othersize, dimsize, stride, eps); +} + +template +INFINIOP_METAX_KERNEL warpLPNormStrides( + Tdata *y, const Tdata *x, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_strides, + const size_t *shape, int ndim, + float p, size_t othersize, size_t dimsize, + float eps) { + warpLPNormStridesKernel( + x, y, output_strides, input_strides, shape, ndim, p, othersize, dimsize, eps); +} + +namespace op::lp_norm::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + int axis, + int p, + float eps) { + auto info = LPNormInfo::createLPNormInfo(y_desc, x_desc, axis, p, eps); + CHECK_RESULT(info); + size_t workspace_size = y_desc->ndim() * (sizeof(ptrdiff_t) * 2 + sizeof(size_t)); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + const LPNormInfo &info, Tdata *y, const Tdata *x, + hcStream_t stream, void *workspace) { + size_t dimsize = info.dimsize; + size_t othersize = info.othersize; + float p_f = static_cast(info.p); + float eps = info.eps; + int num_blocks = static_cast(info.othersize); + ptrdiff_t stride = info.stride; + int ndim = static_cast(info.ndim); + + char *workspace_ptr = reinterpret_cast(workspace); + ptrdiff_t *input_strides_cuda = reinterpret_cast(workspace_ptr); + ptrdiff_t *output_strides_cuda = input_strides_cuda + ndim; + size_t ptrdiff_array_size = 2 * ndim * sizeof(ptrdiff_t); + size_t *shape_cuda = reinterpret_cast(workspace_ptr + ptrdiff_array_size); + + CHECK_METAX(hcMemcpyAsync(input_strides_cuda, info.input_strides.data(), sizeof(ptrdiff_t) * ndim, hcMemcpyHostToDevice, stream)); + CHECK_METAX(hcMemcpyAsync(output_strides_cuda, info.output_strides.data(), sizeof(ptrdiff_t) * ndim, hcMemcpyHostToDevice, stream)); + CHECK_METAX(hcMemcpyAsync(shape_cuda, info.input_shape.data(), sizeof(size_t) * ndim, hcMemcpyHostToDevice, stream)); + + if (info.continuous) { + if (dimsize > 1024) { + blockLPNorm + <<>>(y, x, p_f, dimsize, stride, eps); + } else { + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpLPNorm + <<>>(y, x, p_f, othersize, dimsize, stride, eps); + } + } else { + if (info.axis == ndim - 1) { + if (dimsize > 1024) { + blockLPNormStrides + <<>>( + y, x, output_strides_cuda, input_strides_cuda, shape_cuda, ndim, + p_f, dimsize, eps); + } else { + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpLPNormStrides + <<>>( + y, x, output_strides_cuda, input_strides_cuda, shape_cuda, ndim, + p_f, othersize, dimsize, eps); + } + } else { + return INFINI_STATUS_BAD_PARAM; + } + } + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream_) const { + hcStream_t stream = (hcStream_t)stream_; + +#define CALCULATE_LP_NORM(BLOCK_SIZE, TDATA) \ + launchKernel(_info, (TDATA *)y, (const TDATA *)x, stream, workspace) + +#define CALCULATE_LP_NORM_WITH_BLOCK_SIZE(BLOCK_SIZE) \ + { \ + if (_info.dtype == INFINI_DTYPE_F16) \ + return CALCULATE_LP_NORM(BLOCK_SIZE, half); \ + else if (_info.dtype == INFINI_DTYPE_F32) \ + return CALCULATE_LP_NORM(BLOCK_SIZE, float); \ + else if (_info.dtype == INFINI_DTYPE_BF16) \ + return CALCULATE_LP_NORM(BLOCK_SIZE, __nv_bfloat16); \ + else \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + + if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { + CALCULATE_LP_NORM_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024) + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) { + CALCULATE_LP_NORM_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512) + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + +#undef CALCULATE_LP_NORM_WITH_BLOCK_SIZE +#undef CALCULATE_LP_NORM + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::lp_norm::metax diff --git a/src/infiniop/ops/lp_norm/operator.cc b/src/infiniop/ops/lp_norm/operator.cc index c252e6d51..ba779da01 100644 --- a/src/infiniop/ops/lp_norm/operator.cc +++ b/src/infiniop/ops/lp_norm/operator.cc @@ -5,6 +5,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API) || defined(ENABLE_HYGON_API) #include "nvidia/lp_norm_nvidia.cuh" #endif +#ifdef ENABLE_METAX_API +#include "metax/lp_norm_metax.h" +#endif __INFINI_C infiniStatus_t infiniopCreateLPNormDescriptor( infiniopHandle_t handle, @@ -42,6 +45,9 @@ __INFINI_C infiniStatus_t infiniopCreateLPNormDescriptor( #ifdef ENABLE_HYGON_API CREATE(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -72,6 +78,9 @@ __INFINI_C infiniStatus_t infiniopGetLPNormWorkspaceSize(infiniopLPNormDescripto #ifdef ENABLE_HYGON_API GET(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -115,6 +124,9 @@ __INFINI_C infiniStatus_t infiniopLPNorm( #ifdef ENABLE_HYGON_API CALCULATE(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -148,6 +160,9 @@ infiniopDestroyLPNormDescriptor(infiniopLPNormDescriptor_t desc) { #ifdef ENABLE_HYGON_API DELETE(INFINI_DEVICE_HYGON, nvidia); #endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;