diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h index 10bd07baee2..f6bfae9bdaa 100644 --- a/kernels/portable/cpu/util/broadcast_util.h +++ b/kernels/portable/cpu/util/broadcast_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include @@ -290,23 +291,18 @@ inline void apply_binary_elementwise_fn( const CTYPE_B* const data_b = b.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - size_t a_linear_index = i; - size_t b_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - - if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); - } - if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); - } + if (any_is_broadcasted) { + for (const auto [out_index, a_index, b_index] : + BroadcastIndexesRange<2>(out, a, b)) { + data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]); } + } else { + for (const auto i : c10::irange(out.numel())) { + size_t a_linear_index = i; + size_t b_linear_index = i; - data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); + data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]); + } } } @@ -338,28 +334,16 @@ inline void apply_ternary_elementwise_fn( const CTYPE_C* const data_c = c.const_data_ptr(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - size_t a_linear_index = i; - size_t b_linear_index = i; - size_t c_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - - if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); - } - if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); - } - if (c_is_broadcasted) { - c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); - } + if (any_is_broadcasted) { + for (const auto [out_index, a_index, b_index, c_index] : + BroadcastIndexesRange<3>(out, a, b, c)) { + data_out[out_index] = + compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]); + } + } else { + for (const auto i : c10::irange(out.numel())) { + data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]); } - - data_out[i] = compute_fun( - data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]); } } diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 778006f1b99..09db5f7180d 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - - if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); - } - if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); - } + if (any_is_broadcasted) { + for (const auto [out_index, a_index, b_index] : + BroadcastIndexesRange<2>(out, a, b)) { + auto result = compute_fun( + load_a_to_common(&data_a[a_index * a_element_size]), + load_b_to_common(&data_b[b_index * b_element_size])); + store_common_to_out(result, &data_out[out_index * out_element_size]); + } + } else { + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); } - - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); } } @@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn( char* const data_out = reinterpret_cast(out.mutable_data_ptr()); auto out_numel = out.numel(); - for (const auto i : c10::irange(out_numel)) { - size_t a_linear_index = i; - size_t b_linear_index = i; - size_t c_linear_index = i; - - if (any_is_broadcasted) { - size_t out_indexes[kTensorDimensionLimit]; - delinearize_index(i, out, out_indexes, kTensorDimensionLimit); - - if (a_is_broadcasted) { - a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a); - } - if (b_is_broadcasted) { - b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b); - } - if (c_is_broadcasted) { - c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c); - } + if (any_is_broadcasted) { + for (const auto [out_index, a_index, b_index, c_index] : + BroadcastIndexesRange<3>(out, a, b, c)) { + auto result = compute_fun( + load_a_to_common(&data_a[a_index * a_element_size]), + load_b_to_common(&data_b[b_index * b_element_size]), + load_c_to_common(&data_c[c_index * c_element_size])); + store_common_to_out(result, &data_out[out_index * out_element_size]); + } + } else { + for (const auto i : c10::irange(out_numel)) { + size_t a_linear_index = i; + size_t b_linear_index = i; + size_t c_linear_index = i; + + auto result = compute_fun( + load_a_to_common(&data_a[a_linear_index * a_element_size]), + load_b_to_common(&data_b[b_linear_index * b_element_size]), + load_c_to_common(&data_c[c_linear_index * c_element_size])); + store_common_to_out(result, &data_out[i * out_element_size]); } - - auto result = compute_fun( - load_a_to_common(&data_a[a_linear_index * a_element_size]), - load_b_to_common(&data_b[b_linear_index * b_element_size]), - load_c_to_common(&data_c[c_linear_index * c_element_size])); - store_common_to_out(result, &data_out[i * out_element_size]); } } diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index c42f38fd8b0..739bc117fbf 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -70,6 +70,9 @@ def define_common_targets(): exported_headers = [ "broadcast_util.h", ], + exported_deps = [ + ":broadcast_indexes_range", + ], deps = [ ":repeat_util", "//executorch/runtime/kernel:kernel_includes",