From c5394d2d6a967ab8913348bbbee533607db4c4e6 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 10 Apr 2024 13:49:17 -0700 Subject: [PATCH] Custom ops API small fixes (#2936) Summary: Fix the way we use `at::from_blob()` and add proper namespace to `CompileTimeFunctionPointer` so to not confused with `at::CompileTimeFunctionPointer`. bypass-github-pytorch-ci-checks bypass-export-ci-checks Reviewed By: lucylq Differential Revision: D55907751 --- extension/aten_util/make_aten_functor_from_et_functor.h | 3 +-- extension/aten_util/targets.bzl | 1 + extension/kernel_util/meta_programming.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h index 976549af8db..92d19c04843 100644 --- a/extension/aten_util/make_aten_functor_from_et_functor.h +++ b/extension/aten_util/make_aten_functor_from_et_functor.h @@ -149,8 +149,7 @@ struct type_convert< } c10::ScalarType scalar_type = static_cast(val.scalar_type()); - converted = - at::from_blob(val.mutable_data_ptr(), val.numel(), sizes, scalar_type); + converted = at::from_blob(val.mutable_data_ptr(), sizes, scalar_type); } ATensor call() { return converted; diff --git a/extension/aten_util/targets.bzl b/extension/aten_util/targets.bzl index b396cb78325..6e325830292 100644 --- a/extension/aten_util/targets.bzl +++ b/extension/aten_util/targets.bzl @@ -27,6 +27,7 @@ def define_common_targets(): ], exported_deps = [ "//executorch/extension/kernel_util:kernel_util", + "//executorch/extension/runner_util:managed_tensor", "//executorch/runtime/core:core", "//executorch/runtime/core:evalue", "//executorch/runtime/core/exec_aten:lib", diff --git a/extension/kernel_util/meta_programming.h b/extension/kernel_util/meta_programming.h index 46262b843ea..c412e907ea0 100644 --- a/extension/kernel_util/meta_programming.h +++ b/extension/kernel_util/meta_programming.h @@ -49,7 +49,7 @@ struct is_compile_time_function_pointer< CompileTimeFunctionPointer> : std::true_type {}; #define EXECUTORCH_FN_TYPE(func) \ - CompileTimeFunctionPointer< \ + ::torch::executor::CompileTimeFunctionPointer< \ std::remove_pointer_t>, \ func> #define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()