@@ -163,25 +163,21 @@ Tensor& div_scalar_out(
163163
164164 ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " div.Scalar_out" , CTYPE_A, [&]() {
165165 ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, " div.Scalar_out" , CTYPE_B, [&]() {
166- ET_SWITCH_FLOAT_TYPES (
167- common_type, ctx, " div.Scalar_out" , CTYPE_IN, [&]() {
168- ET_SWITCH_FLOAT_TYPES (
169- out_type, ctx, " div.Scalar_out" , CTYPE_OUT, [&]() {
170- CTYPE_B b_val;
171- utils::extract_scalar (b, &b_val);
172- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
173-
174- apply_unary_map_fn (
175- [b_casted](const CTYPE_A val_a) {
176- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
177- CTYPE_IN value = a_casted / b_casted;
178- return static_cast <CTYPE_OUT>(value);
179- },
180- a.const_data_ptr <CTYPE_A>(),
181- out.mutable_data_ptr <CTYPE_OUT>(),
182- out.numel ());
183- });
184- });
166+ ET_SWITCH_FLOAT_TYPES (out_type, ctx, " div.Scalar_out" , CTYPE, [&]() {
167+ CTYPE_B b_val;
168+ utils::extract_scalar (b, &b_val);
169+ CTYPE b_casted = static_cast <CTYPE>(b_val);
170+
171+ apply_unary_map_fn (
172+ [b_casted](const CTYPE_A val_a) {
173+ CTYPE a_casted = static_cast <CTYPE>(val_a);
174+ CTYPE value = a_casted / b_casted;
175+ return static_cast <CTYPE>(value);
176+ },
177+ a.const_data_ptr <CTYPE_A>(),
178+ out.mutable_data_ptr <CTYPE>(),
179+ out.numel ());
180+ });
185181 });
186182 });
187183
@@ -206,7 +202,7 @@ Tensor& div_scalar_mode_out(
206202
207203 ScalarType a_type = a.scalar_type ();
208204 ScalarType b_type = utils::get_scalar_dtype (b);
209- ScalarType common_type = isFloatingType (a_type) ? a_type : ScalarType::Float ;
205+ ScalarType common_type = utils::promote_type_with_scalar (a_type, b) ;
210206 ScalarType out_type = out.scalar_type ();
211207
212208 ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
@@ -215,27 +211,25 @@ Tensor& div_scalar_mode_out(
215211
216212 ET_SWITCH_REALB_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
217213 ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
218- ET_SWITCH_FLOAT_TYPES (common_type, ctx, name, CTYPE_IN, [&]() {
219- ET_SWITCH_FLOAT_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
220- CTYPE_B b_val;
221- utils::extract_scalar (b, &b_val);
222- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
223-
224- apply_unary_map_fn (
225- [b_casted, mode](const CTYPE_A val_a) {
226- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
227- CTYPE_IN value = a_casted / b_casted;
228- if (mode.has_value () && mode.value () == " trunc" ) {
229- value = std::trunc (value);
230- } else if (mode.has_value () && mode.value () == " floor" ) {
231- value = utils::floor_divide (a_casted, b_casted);
232- }
233- return static_cast <CTYPE_OUT>(value);
234- },
235- a.const_data_ptr <CTYPE_A>(),
236- out.mutable_data_ptr <CTYPE_OUT>(),
237- out.numel ());
238- });
214+ ET_SWITCH_REAL_TYPES (out_type, ctx, name, CTYPE, [&]() {
215+ CTYPE_B b_val;
216+ utils::extract_scalar (b, &b_val);
217+ CTYPE b_casted = static_cast <CTYPE>(b_val);
218+
219+ apply_unary_map_fn (
220+ [b_casted, mode](const CTYPE_A val_a) {
221+ CTYPE a_casted = static_cast <CTYPE>(val_a);
222+ CTYPE value = a_casted / b_casted;
223+ if (mode.has_value () && mode.value () == " trunc" ) {
224+ value = std::trunc (value);
225+ } else if (mode.has_value () && mode.value () == " floor" ) {
226+ value = utils::floor_divide (a_casted, b_casted);
227+ }
228+ return value;
229+ },
230+ a.const_data_ptr <CTYPE_A>(),
231+ out.mutable_data_ptr <CTYPE>(),
232+ out.numel ());
239233 });
240234 });
241235 });
0 commit comments