Skip to content

Commit ba660c1

Browse files
committed
Integrating NNLIB optimization for rsqrt op
1 parent 87aa389 commit ba660c1

3 files changed

Lines changed: 58 additions & 1 deletion

File tree

backends/cadence/aot/functions_hifi.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@
101101
kernels:
102102
- arg_meta: null
103103
kernel_name: torch::executor::where_out
104+
105+
- op: rsqrt.out
106+
kernels:
107+
- arg_meta: null
108+
kernel_name: torch::executor::rsqrt_out
104109

105110
# custom ops
106111
- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)

backends/cadence/hifi/operators/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ set(_aten_ops__srcs
3636
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
3737
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
3838
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
39+
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_rsqrt.cpp"
3940
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp"
4041
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
4142
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp"
@@ -44,7 +45,8 @@ set(_aten_ops__srcs
4445
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp"
4546
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp"
4647
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp"
47-
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp")
48+
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp"
49+
"${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp")
4850
add_library(aten_ops_cadence ${_aten_ops__srcs})
4951
target_link_libraries(aten_ops_cadence PUBLIC executorch)
5052
target_link_libraries(aten_ops_cadence PRIVATE cadence_kernels)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/pattern/pattern.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
12+
#include "kernels.h"
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
namespace {
18+
19+
// Scalar reciprocal square root: 1 / sqrt(x), computed in double precision.
// Handed to the portable unary pattern as the per-element function.
double rsqrt(double x) {
  const double root = std::sqrt(x);
  return 1.0 / root;
}
22+
23+
} // namespace
24+
25+
/**
 * Computes the element-wise reciprocal square root of `in` into `out`.
 *
 * When both tensors hold Float data and have the same element count, the
 * NNLIB kernel xa_nn_elm_rsqrt_f32_f32 is invoked directly on the raw
 * buffers. Every other case is forwarded to the portable
 * internal::unary_ufunc_realhb_to_floath pattern with the scalar rsqrt helper.
 *
 * @param ctx Runtime context, forwarded to the portable fallback.
 * @param in Input tensor.
 * @param out Output tensor, written in place.
 * @return Reference to `out`.
 */
Tensor& rsqrt_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
  // The fast path reinterprets both buffers as float and processes
  // out.numel() elements, so it is only safe when the output is Float too and
  // the element counts agree. Checking the input dtype alone would let a
  // non-Float or differently sized output be written through a float pointer.
  const bool use_nnlib = in.scalar_type() == ScalarType::Float &&
      out.scalar_type() == ScalarType::Float && in.numel() == out.numel();

  if (!use_nnlib) {
    // Portable path; also covers the remaining dtype combinations.
    return internal::unary_ufunc_realhb_to_floath(rsqrt, ctx, in, out);
  }

  WORD32 num_elm = out.numel();

  FLOAT32* __restrict__ p_out =
      (FLOAT32* __restrict__)out.mutable_data_ptr<float>();
  const FLOAT32* __restrict__ p_inp =
      (const FLOAT32* __restrict__)in.const_data_ptr<float>();

  xa_nn_elm_rsqrt_f32_f32(p_out, p_inp, num_elm);

  return out;
}
47+
48+
} // namespace native
49+
} // namespace executor
50+
} // namespace torch

0 commit comments

Comments
 (0)