@@ -20,11 +20,6 @@ namespace native {
2020namespace {
2121
2222ScalarType get_compute_type (ScalarType a_type, ScalarType b_type) {
23- ET_CHECK (
24- !isComplexType (a_type) && !isQIntType (a_type) && !isBitsType (a_type));
25- ET_CHECK (
26- !isComplexType (b_type) && !isQIntType (b_type) && !isBitsType (b_type));
27-
2823 if (isFloatingType (a_type) && isFloatingType (b_type)) {
2924 return promoteTypes (a_type, b_type);
3025 } else if (isFloatingType (a_type)) {
@@ -47,6 +42,18 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
4742
4843 ScalarType a_type = a.scalar_type ();
4944 ScalarType b_type = b.scalar_type ();
45+
46+ ET_KERNEL_CHECK (
47+ ctx,
48+ !isComplexType (a_type) && !isQIntType (a_type) && !isBitsType (a_type),
49+ InvalidArgument,
50+ out);
51+ ET_KERNEL_CHECK (
52+ ctx,
53+ !isComplexType (b_type) && !isQIntType (b_type) && !isBitsType (b_type),
54+ InvalidArgument,
55+ out);
56+
5057 ScalarType common_type = get_compute_type (a_type, b_type);
5158 ScalarType out_type = out.scalar_type ();
5259
@@ -94,7 +101,11 @@ Tensor& div_out_mode(
94101
95102 // Allow casting float -> integral here
96103 // non-bool -> bool is still disallowed
97- ET_CHECK (!(common_type != ScalarType::Bool && out_type == ScalarType::Bool));
104+ ET_KERNEL_CHECK (
105+ ctx,
106+ !(common_type != ScalarType::Bool && out_type == ScalarType::Bool),
107+ InvalidArgument,
108+ out);
98109
99110 ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " div.out_mode" , CTYPE_A, [&]() {
100111 ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, " div.out_mode" , CTYPE_B, [&]() {
@@ -131,15 +142,19 @@ Tensor& div_scalar_out(
131142 (void )ctx;
132143
133144 // Resize for dynamic shape
134- auto error = resize_tensor (out, a.sizes ());
135- ET_CHECK_MSG (error == Error::Ok, " Failed to resize output tensor." );
145+ ET_KERNEL_CHECK_MSG (
146+ ctx,
147+ resize_tensor (out, a.sizes ()) == Error::Ok,
148+ InvalidArgument,
149+ out,
150+ " Failed to resize output tensor." );
136151
137152 ScalarType a_type = a.scalar_type ();
138153 ScalarType b_type = utils::get_scalar_dtype (b);
139154 ScalarType common_type = isFloatingType (a_type) ? a_type : ScalarType::Float;
140155 ScalarType out_type = out.scalar_type ();
141156
142- ET_CHECK ( common_type == out_type);
157+ ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out );
143158
144159 ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " div.Scalar_out" , CTYPE_A, [&]() {
145160 ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, " div.Scalar_out" , CTYPE_B, [&]() {
0 commit comments