From 222cb28a4f30797c7444e1307cfcd9f7fba9a49f Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 1 Jul 2026 19:17:16 +0800 Subject: [PATCH] feat: allow embedding and add rms norm on hygon --- src/infinicore/nn/embedding.cc | 2 +- src/infinicore/nn/rmsnorm.cc | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/infinicore/nn/embedding.cc b/src/infinicore/nn/embedding.cc index e624ebd1a..bd2b39926 100644 --- a/src/infinicore/nn/embedding.cc +++ b/src/infinicore/nn/embedding.cc @@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings, Tensor Embedding::forward(const Tensor &indices) const { // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach auto device_type = device_.getType(); - if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ASCEND || device_type == Device::Type::CAMBRICON || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY) { + if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ASCEND || device_type == Device::Type::CAMBRICON || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::QY || device_type == Device::Type::HYGON) { // Use op::embedding which supports device-side input and batch dimension return op::embedding(indices->contiguous()->to(device_), weight_); } diff --git a/src/infinicore/nn/rmsnorm.cc b/src/infinicore/nn/rmsnorm.cc index 24d090049..567d88e49 100644 --- a/src/infinicore/nn/rmsnorm.cc +++ b/src/infinicore/nn/rmsnorm.cc @@ -32,7 +32,8 @@ void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const { || device_.getType() == Device::Type::METAX || device_.getType() == Device::Type::MOORE || device_.getType() == Device::Type::ALI - || device_.getType() == Device::Type::CAMBRICON) { + || device_.getType() == Device::Type::CAMBRICON + || device_.getType() == Device::Type::HYGON) { op::add_rms_norm_inplace(x, residual, weight_, static_cast(eps_)); } else { op::add_(residual, x, residual);