66 * LICENSE file in the root directory of this source tree.
77 */
88
9- #include < cmath>
9+ // patternlint-disable-next-line executorch-cpp-nostdinc
10+ #include < functional>
1011
12+ #include < executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1113#include < executorch/kernels/portable/cpu/scalar_utils.h>
1214#include < executorch/kernels/portable/cpu/util/broadcast_util.h>
1315#include < executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,28 +19,13 @@ namespace torch {
1719namespace executor {
1820namespace native {
1921
20- namespace {
21-
22- template <typename CTYPE>
23- CTYPE bitwise_xor (CTYPE a, CTYPE b) {
24- return a ^ b;
25- }
26-
27- template <>
28- bool bitwise_xor<bool >(bool a, bool b) {
29- return a != b;
30- }
31-
32- } // namespace
33-
3422using Tensor = exec_aten::Tensor;
3523
3624Tensor& bitwise_xor_Tensor_out (
3725 RuntimeContext& ctx,
3826 const Tensor& a,
3927 const Tensor& b,
4028 Tensor& out) {
41- // Determine output size and resize for dynamic shapes
4229 ET_KERNEL_CHECK (
4330 ctx,
4431 resize_to_broadcast_target_size (a, b, out) == Error::Ok,
@@ -56,38 +43,23 @@ Tensor& bitwise_xor_Tensor_out(
5643 Bool, a_type, ctx, " bitwise_xor.Tensor_out" , CTYPE_A, [&]() {
5744 ET_SWITCH_INT_TYPES_AND (
5845 Bool, b_type, ctx, " bitwise_xor.Tensor_out" , CTYPE_B, [&]() {
59- ET_SWITCH_INT_TYPES_AND (
46+ using CTYPE_IN = typename torch::executor::
47+ promote_types<CTYPE_A, CTYPE_B>::type;
48+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
49+ ET_SWITCH_REAL_TYPES_AND (
6050 Bool,
61- common_type ,
51+ out_type ,
6252 ctx,
6353 " bitwise_xor.Tensor_out" ,
64- CTYPE_IN ,
54+ CTYPE_OUT ,
6555 [&]() {
66- ET_SWITCH_REAL_TYPES_AND (
67- Bool,
68- out_type,
69- ctx,
70- " bitwise_xor.Tensor_out" ,
71- CTYPE_OUT,
72- [&]() {
73- apply_binary_elementwise_fn<
74- CTYPE_A,
75- CTYPE_B,
76- CTYPE_OUT>(
77- [](const CTYPE_A val_a, const CTYPE_B val_b) {
78- CTYPE_IN a_casted =
79- static_cast <CTYPE_IN>(val_a);
80- CTYPE_IN b_casted =
81- static_cast <CTYPE_IN>(val_b);
82- CTYPE_IN value =
83- bitwise_xor (a_casted, b_casted);
84-
85- return static_cast <CTYPE_OUT>(value);
86- },
87- a,
88- b,
89- out);
90- });
56+ internal::BitwiseOpInner<
57+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
58+ std::bit_xor,
59+ CTYPE_A,
60+ CTYPE_B,
61+ CTYPE_IN,
62+ CTYPE_OUT>::run (a, b, out);
9163 });
9264 });
9365 });
@@ -143,8 +115,8 @@ Tensor& bitwise_xor_Scalar_out(
143115 static_cast <CTYPE_IN>(val_a);
144116 CTYPE_IN b_casted =
145117 static_cast <CTYPE_IN>(val_b);
146- CTYPE_IN value =
147- bitwise_xor ( a_casted, b_casted);
118+ CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
119+ a_casted, b_casted);
148120
149121 return static_cast <CTYPE_OUT>(value);
150122 },
0 commit comments