diff --git a/extension/aten_util/make_aten_functor_from_et_functor.h b/extension/aten_util/make_aten_functor_from_et_functor.h
new file mode 100644
index 00000000000..976099f88fa
--- /dev/null
+++ b/extension/aten_util/make_aten_functor_from_et_functor.h
@@ -0,0 +1,178 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+//===----------------------------------------------------------------------===//
+/// \file runtime/kernel/make_aten_functor_from_et_functor.h
+/// Defines a template that can be used to create a ATen version of an unboxed
+/// ExecuTorch kernel.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include <type_traits>
+#if __cplusplus < 201703L
+#error "This header requires C++17"
+#endif
+#include <ATen/native/Resize.h>
+#include <executorch/extension/kernel_util/type_list.h>
+#include <executorch/extension/runner_util/managed_tensor.h>
+#include <torch/torch.h>
+
+namespace torch {
+namespace executor {
+
+class KernelRuntimeContext; // Forward declaration
+using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove
+
+template <typename T>
+struct type_map final {
+  using type = T;
+};
+
+template <>
+struct type_map<torch::executor::Tensor&> final {
+  using type = at::Tensor&;
+};
+
+template <>
+struct type_map<const torch::executor::Tensor&> final {
+  using type = const at::Tensor&;
+};
+
+template <typename F, typename T, typename Enable = void>
+struct type_convert final {
+ public:
+  F val;
+  explicit type_convert(F value) : val(value) {}
+  T call() {
+    return static_cast<T>(val);
+  }
+};
+
+template <typename T>
+struct remove_const_ref final {
+  using type = std::remove_const_t<std::remove_reference_t<T>>;
+};
+
+template <class ATensor, class ETensor>
+struct type_convert<
+    ATensor,
+    ETensor,
+    std::enable_if_t<
+        std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> &&
+        std::is_same_v<
+            typename remove_const_ref<ETensor>::type,
+            torch::executor::Tensor>>>
+    final {
+ public:
+  ATensor val;
+  std::unique_ptr<ManagedTensor> managed_tensor;
+  torch::executor::Tensor converted;
+  std::vector<exec_aten::SizesType> sizes;
+  explicit type_convert(ATensor value)
+      : val(value),
+        converted(torch::executor::Tensor(nullptr)) {
+    for (auto size : val.sizes()) {
+      sizes.push_back(size);
+    }
+    torch::executor::ScalarType scalar_type =
+        static_cast<torch::executor::ScalarType>(val.scalar_type());
+    managed_tensor = std::make_unique<ManagedTensor>(
+        val.mutable_data_ptr(), val.numel(), sizes, scalar_type);
+    converted = managed_tensor->get_aliasing_tensor();
+  }
+  ETensor call() {
+    return converted;
+  }
+};
+
+template <>
+struct type_convert<torch::executor::Tensor&, at::Tensor&> final {
+ public:
+  torch::executor::Tensor& val;
+  at::Tensor converted;
+  std::vector<int64_t> sizes;
+  explicit type_convert(torch::executor::Tensor& value) : val(value) {
+    for (auto size : val.sizes()) {
+      sizes.push_back(size);
+    }
+    c10::ScalarType scalar_type =
+        static_cast<c10::ScalarType>(val.scalar_type());
+    converted =
+        at::from_blob(val.mutable_data_ptr(), val.numel(), sizes, scalar_type);
+  }
+  at::Tensor& call() {
+    return converted;
+  }
+};
+
+template <class F, F f, int N = -1>
+struct wrapper_impl;
+
+template <class R, class... Args, R (*f)(Args...), int N>
+struct wrapper_impl<R (*)(Args...), f, N> {
+  static_assert(
+      !(std::is_same<R, torch::executor::Tensor&>::value && N == -1),
+      "Can't wrap a kernel with 'Tensor &' return type without specifying an index to the out tensor");
+  using ReturnType = typename type_map<R>::type;
+  using TupleConvertsType =
+      std::tuple<type_convert<typename type_map<Args>::type, Args>...>;
+  using TupleArgsType = std::tuple<typename type_map<Args>::type...>;
+  static constexpr size_t num_args = sizeof...(Args);
+  static_assert(
+      (N < num_args &&
+       std::is_same_v<element_t<N, typelist<Args...>>, R>) ||
+          N == -1,
+      "The index of the out tensor can't be greater or equal to num_args and "
+      "the Nth argument type has to be the same as the return type.");
+
+  static ReturnType wrap(typename type_map<Args>::type... args) {
+    // The wrapped function that takes ATen argument types, convert them into
+    // ExecuTorch equivalent, call `f` then return the result converted back to
+    // ATen.
+    TupleArgsType args_tuple = std::forward_as_tuple(args...);
+    TupleConvertsType converts = std::forward_as_tuple(
+        type_convert<typename type_map<Args>::type, Args>(args)...);
+    R result =
+        call_functor_with_args(converts, std::make_index_sequence<num_args>());
+    typename std::remove_reference<ReturnType>::type converted_result =
+        type_convert<R, ReturnType>(result).call();
+    if constexpr (N == -1) {
+      return converted_result;
+    } else {
+      static_assert(
+          std::is_same_v<
+              typename std::remove_reference<ReturnType>::type,
+              at::Tensor>,
+          "Only support at::Tensor-like return");
+      ReturnType out = std::get<N>(args_tuple);
+      at::native::resize_output(out, converted_result.sizes());
+      out.copy_(converted_result);
+      return out;
+    }
+  }
+
+ private:
+  template <size_t... indices>
+  static R call_functor_with_args(
+      TupleConvertsType& converts,
+      std::index_sequence<indices...>) {
+    return f(std::get<indices>(converts).call()...);
+  }
+};
+
+// Wrapper macro for out variant function. N is the index of the out tensor.
+// We need N to know how to preserve the semantics of modifying out tensor and
+// return the reference without allocating a new memory buffer for out tensor.
+#define _WRAP_2(func, N) \
+  wrapper_impl<decltype(&func), func, N>::wrap
+#define _WRAP_1(func) wrapper_impl<decltype(&func), func>::wrap
+
+#define GET_MACRO(_1, _2, NAME, ...) NAME
+#define WRAP_TO_ATEN(...) \
+  GET_MACRO(__VA_ARGS__, _WRAP_2, _WRAP_1)(__VA_ARGS__)
+
+} // namespace executor
+} // namespace torch
diff --git a/extension/aten_util/targets.bzl b/extension/aten_util/targets.bzl
index a351b7d5a22..b396cb78325 100644
--- a/extension/aten_util/targets.bzl
+++ b/extension/aten_util/targets.bzl
@@ -10,7 +10,7 @@ def define_common_targets():
     runtime.cxx_library(
         name = "aten_bridge",
         srcs = ["aten_bridge.cpp"],
-        exported_headers = ["aten_bridge.h"],
+        exported_headers = ["aten_bridge.h", "make_aten_functor_from_et_functor.h"],
         compiler_flags = [
             "-frtti",
             "-fno-omit-frame-pointer",
@@ -25,8 +25,10 @@ def define_common_targets():
             "//executorch/...",
             "@EXECUTORCH_CLIENTS",
         ],
-        deps = [
+        exported_deps = [
+            "//executorch/extension/kernel_util:kernel_util",
            "//executorch/runtime/core:core",
+            "//executorch/runtime/core:evalue",
            "//executorch/runtime/core/exec_aten:lib",
        ],
        external_deps = [
diff --git a/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp
new file mode 100644
index 00000000000..bf6b60fe63c
--- /dev/null
+++ b/extension/aten_util/test/make_aten_functor_from_et_functor_test.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/platform/runtime.h>
+#include <gtest/gtest.h>
+#include <torch/library.h>
+#include <torch/torch.h>
+
+namespace torch {
+namespace executor {
+
+using namespace ::testing;
+
+Tensor& my_op_out(const Tensor& a, Tensor& out) {
+  (void)a;
+  return out;
+}
+
+Tensor& add_1_out(const Tensor& a, Tensor& out) {
+  (void)a;
+  out.mutable_data_ptr<int32_t>()[0] += 1;
+  return out;
+}
+
+Tensor& quantized_embedding_byte_out(
+    const Tensor& weight,
+    const Tensor& weight_scales,
+    const Tensor& weight_zero_points,
+    int64_t weight_quant_min,
+    int64_t weight_quant_max,
+    const Tensor& indices,
+    Tensor& out) {
+  (void)weight;
+  (void)weight_scales;
+  (void)weight_zero_points;
+  (void)weight_quant_min;
+  (void)indices;
+  out.mutable_data_ptr<int32_t>()[0] -= static_cast<int32_t>(weight_quant_max);
+  return out;
+}
+
+class MakeATenFunctorFromETFunctorTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    torch::executor::runtime_init();
+  }
+};
+
+TEST_F(MakeATenFunctorFromETFunctorTest, Basic) {
+  auto function = WRAP_TO_ATEN(my_op_out, 1);
+  at::Tensor a = torch::tensor({1.0f});
+  at::Tensor b = torch::tensor({2.0f});
+  at::Tensor c = function(a, b);
+  EXPECT_EQ(c.const_data_ptr<float>()[0], 2.0f);
+}
+
+TORCH_LIBRARY(my_op, m) {
+  m.def("add_1.out", WRAP_TO_ATEN(add_1_out, 1));
+  m.def(
+      "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+      WRAP_TO_ATEN(quantized_embedding_byte_out, 6));
+};
+
+TEST_F(MakeATenFunctorFromETFunctorTest, RegisterWrappedFunction) {
+  auto op = c10::Dispatcher::singleton().findSchema({"my_op::add_1", "out"});
+  EXPECT_TRUE(op.has_value());
+  at::Tensor a =
+      torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32));
+  at::Tensor b =
+      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
+  torch::jit::Stack stack = {a, b};
+  op.value().callBoxed(&stack);
+  EXPECT_EQ(stack.size(), 1);
+  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int32_t>()[0], 3);
+}
+
+TEST_F(MakeATenFunctorFromETFunctorTest, TestEmbeddingByte) {
+  auto op =
+      c10::Dispatcher::singleton().findSchema({"my_op::embedding_byte", "out"});
+  EXPECT_TRUE(op.has_value());
+  at::Tensor weight =
+      torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt32));
+  at::Tensor scale =
+      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
+  at::Tensor zero_point =
+      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
+  at::Tensor indices =
+      torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt32));
+  at::Tensor out =
+      torch::tensor({4}, torch::TensorOptions().dtype(torch::kInt32));
+  torch::jit::Stack stack = {weight, scale, zero_point, 0, 1, indices, out};
+  op.value().callBoxed(&stack);
+  EXPECT_EQ(stack.size(), 1);
+  EXPECT_EQ(stack[0].toTensor().const_data_ptr<int32_t>()[0], 3);
+}
+
+} // namespace executor
+} // namespace torch
diff --git a/extension/aten_util/test/targets.bzl b/extension/aten_util/test/targets.bzl
index 6257595e358..58c00a90316 100644
--- a/extension/aten_util/test/targets.bzl
+++ b/extension/aten_util/test/targets.bzl
@@ -9,22 +9,19 @@ def define_common_targets():
 
     runtime.cxx_test(
         name = "aten_bridge_test",
-        srcs = ["aten_bridge_test.cpp"],
+        srcs = [
+            "aten_bridge_test.cpp",
+            "make_aten_functor_from_et_functor_test.cpp",
+        ],
         deps = [
             "//executorch/runtime/core:core",
             "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/kernel:operator_registry",
             "//executorch/extension/aten_util:aten_bridge",
+            "//executorch/extension/runner_util:managed_tensor",
         ],
-        fbcode_deps = [
-            "//caffe2:ATen-core",
-            "//caffe2:ATen-cpu",
-            "//caffe2/c10:c10",
-        ],
-        xplat_deps = [
-            "//xplat/caffe2:torch_mobile_core",
-            "//xplat/caffe2/c10:c10",
-            # Dont really like this but without this I dont have aten::empty
-            # And havent figured out a more minimal target
-            "//xplat/caffe2:torch_mobile_all_ops_et",
+        external_deps = [
+            "libtorch",
+            "gtest_aten",
         ],
     )
diff --git a/extension/runner_util/managed_tensor.h b/extension/runner_util/managed_tensor.h
index 63a20081251..a3e20ce7ef1 100644
--- a/extension/runner_util/managed_tensor.h
+++ b/extension/runner_util/managed_tensor.h
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/portable_type/tensor.h>
 #ifdef USE_ATEN_LIB
 #include <torch/torch.h>