diff --git a/CMakeLists.txt b/CMakeLists.txt index 96739dae34d..7eeda5b8c7c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -144,6 +144,8 @@ option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF) option(EXECUTORCH_BUILD_CUSTOM "Build the custom kernels" OFF) +option(EXECUTORCH_BUILD_CUSTOM_OPS_AOT "Build the custom ops lib for AOT" OFF) + option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension" OFF) @@ -185,12 +187,19 @@ cmake_dependent_option( cmake_dependent_option(EXECUTORCH_BUILD_CPUINFO "Build cpuinfo library." ON "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF) +if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT) + set(EXECUTORCH_BUILD_CUSTOM ON) +endif() + if(EXECUTORCH_BUILD_CUSTOM) set(EXECUTORCH_BUILD_OPTIMIZED ON) endif() if(EXECUTORCH_BUILD_CPUINFO) # --- cpuinfo + set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG + ${CMAKE_POSITION_INDEPENDENT_CODE}) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CPUINFO_SOURCE_DIR "backends/xnnpack/third-party/cpuinfo") set(CPUINFO_BUILD_TOOLS OFF @@ -212,10 +221,15 @@ if(EXECUTORCH_BUILD_CPUINFO) CACHE STRING "") set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog") add_subdirectory("${CPUINFO_SOURCE_DIR}") + set(CMAKE_POSITION_INDEPENDENT_CODE + ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG}) endif() if(EXECUTORCH_BUILD_PTHREADPOOL) # --- pthreadpool + set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG + ${CMAKE_POSITION_INDEPENDENT_CODE}) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(PTHREADPOOL_SOURCE_DIR "backends/xnnpack/third-party/pthreadpool") set(PTHREADPOOL_BUILD_TESTS OFF @@ -235,6 +249,8 @@ if(EXECUTORCH_BUILD_PTHREADPOOL) CACHE STRING "") endif() add_subdirectory("${PTHREADPOOL_SOURCE_DIR}") + set(CMAKE_POSITION_INDEPENDENT_CODE + ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG}) endif() if(NOT PYTHON_EXECUTABLE) @@ -546,6 +562,9 @@ if(EXECUTORCH_BUILD_PYBIND) list(APPEND _dep_libs custom_ops) endif() + if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT) + list(APPEND _dep_libs custom_ops_aot_lib) + endif() # compile options for pybind set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 90a4f98952a..688ea02d6d7 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -81,7 +81,7 @@ add_library(xnnpack_backend STATIC ${_xnnpack_backend__srcs}) target_link_libraries(xnnpack_backend PRIVATE ${xnnpack_third_party} - executorch + executorch_no_prim_ops xnnpack_schema) target_include_directories(xnnpack_backend diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index c93ea6149ff..09ebd5aeada 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -18,7 +18,7 @@ runtime.python_library( ], deps = [ "//caffe2:torch", - "//executorch/examples/models/llama2/custom_ops:llama_custom_ops_aot_lib", + "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py", ], ) @@ -52,6 +52,7 @@ runtime.python_binary( main_module = "executorch.examples.models.llama2.export_llama", # visibility = ["//executorch/examples/..."], preload_deps = [ + "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_lib", "//executorch/kernels/quantized:aot_lib", ], deps = [ diff --git a/examples/models/llama2/custom_ops/CMakeLists.txt b/examples/models/llama2/custom_ops/CMakeLists.txt index d954b29f67b..5075807b8db 100644 --- a/examples/models/llama2/custom_ops/CMakeLists.txt +++ b/examples/models/llama2/custom_ops/CMakeLists.txt @@ -25,7 +25,7 @@ if(NOT TORCH_ROOT) 
  set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
 endif()
 
-set(_common_compile_options -Wno-deprecated-declarations)
+set(_common_compile_options -Wno-deprecated-declarations -fPIC)
 
 include(${EXECUTORCH_ROOT}/build/Utils.cmake)
 include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
@@ -44,7 +44,7 @@ include(${EXECUTORCH_SRCS_FILE})
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)
 
 # Custom op libraries
-set(custom_ops_libs extension_module)
+set(custom_ops_libs executorch_no_prim_ops)
 list(APPEND custom_ops_libs pthreadpool)
 list(APPEND custom_ops_libs cpuinfo)
 list(APPEND custom_ops_libs cpublas)
@@ -76,3 +76,19 @@ target_compile_options(custom_ops PUBLIC ${_common_compile_options}
                                          -DET_USE_THREADPOOL)
 
 install(TARGETS custom_ops DESTINATION lib)
+
+if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
+  # Add an AOT library
+  find_package(Torch CONFIG REQUIRED)
+  add_library(custom_ops_aot_lib SHARED
+              ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp)
+  target_include_directories(custom_ops_aot_lib
+                             PUBLIC "${_common_include_directories}")
+  target_include_directories(
+    custom_ops_aot_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../../../include")
+  target_link_libraries(custom_ops_aot_lib PUBLIC custom_ops torch)
+  target_compile_options(custom_ops_aot_lib PUBLIC -Wno-deprecated-declarations
+                                                   -fPIC -frtti -fexceptions)
+
+  install(TARGETS custom_ops_aot_lib DESTINATION lib)
+endif()
diff --git a/examples/models/llama2/custom_ops/op_sdpa_aot.cpp b/examples/models/llama2/custom_ops/op_sdpa_aot.cpp
new file mode 100644
index 00000000000..ed735406ad5
--- /dev/null
+++ b/examples/models/llama2/custom_ops/op_sdpa_aot.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/examples/models/llama2/custom_ops/op_sdpa.h>
+#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+#include <torch/library.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& sdpa_with_kv_cache_out_no_context(
+    const Tensor& q_projected,
+    const Tensor& k_projected,
+    const Tensor& v_projected,
+    Tensor& key_cache,
+    Tensor& value_cache,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::sdpa_with_kv_cache_out(
+      context,
+      q_projected,
+      k_projected,
+      v_projected,
+      key_cache,
+      value_cache,
+      start_pos,
+      seq_len,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      output);
+}
+
+at::Tensor sdpa_with_kv_cache_aten(
+    const at::Tensor& q_projected,
+    const at::Tensor& k_projected,
+    const at::Tensor& v_projected,
+    at::Tensor& key_cache,
+    at::Tensor& value_cache,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<at::Tensor> attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const c10::optional<double> scale) {
+  auto output = at::empty_like(q_projected);
+  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
+  (q_projected,
+   k_projected,
+   v_projected,
+   key_cache,
+   value_cache,
+   start_pos,
+   seq_len,
+   attn_mask,
+   dropout_p,
+   is_causal,
+   scale,
+   output);
+  return output;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+TORCH_LIBRARY(llama, m) {
+  m.def(
+      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
+      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
+      "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
+  m.def(
+      "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
+      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
+      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+}
+
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
+  m.impl(
+      "sdpa_with_kv_cache.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
+}
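For reference, here is a minimal sketch (not part of the diff) of how the ops registered above can be exercised from Python once `custom_ops_aot_lib` is built. The library path and the tensor shapes are illustrative assumptions; the call itself follows the `m.def` schema above.

```python
import torch

# Assumed build-output path; the real file name and extension are
# platform-dependent (.so on Linux, .dylib on macOS).
torch.ops.load_library(
    "cmake-out/examples/models/llama2/custom_ops/libcustom_ops_aot_lib.so"
)

# Illustrative shapes: [batch, seq_len, n_heads, head_dim] for the
# projections, [batch, max_seq_len, n_heads, head_dim] for the caches.
q = torch.rand(1, 1, 8, 64)
k = torch.rand(1, 1, 8, 64)
v = torch.rand(1, 1, 8, 64)
k_cache = torch.zeros(1, 128, 8, 64)
v_cache = torch.zeros(1, 128, 8, 64)

# start_pos=0, seq_len=1; the caches are mutated in place, per the
# Tensor(a!)/Tensor(b!) annotations in the schema.
out = torch.ops.llama.sdpa_with_kv_cache(q, k, v, k_cache, v_cache, 0, 1)
assert out.shape == q.shape  # output is at::empty_like(q_projected)
```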
diff --git a/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
index 5f11defb11d..bada40220bc 100644
--- a/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/sdpa_with_kv_cache.py
@@ -4,21 +4,29 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Import the custom ops defined in op_sdpa_aot.cpp. Those ops are registered
+# through the PyTorch C++ API, so we need to load the shared library that
+# registers them. This is only needed for OSS.
+
+import logging
+from pathlib import Path
+
 import torch
-from torch.library import impl, impl_abstract
 
-custom_ops_lib = torch.library.Library("llama", "DEF")
-custom_ops_lib.define(
-    "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
-    "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-    "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor"
-)
+from torch.library import impl
 
-custom_ops_lib.define(
-    "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
-    "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-    "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)"
-)
+try:
+    op = torch.ops.llama.sdpa_with_kv_cache.default
+    assert op is not None
+except AttributeError:
+    libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
+    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+    op = torch.ops.llama.sdpa_with_kv_cache.default
+    assert op is not None
+
+custom_ops_lib = torch.library.Library("llama", "IMPL")
 
 
 def _validate_params(
@@ -118,82 +126,3 @@ def sdpa_with_kv_cache_meta(
     )
 
     return torch.empty_like(query)
-
-
-@impl(custom_ops_lib, "sdpa_with_kv_cache", "CompositeExplicitAutograd")
-def sdpa_with_kv_cache(
-    query,
-    key,
-    value,
-    key_cache,
-    value_cache,
-    start_pos,
-    seq_len,
-    attn_mask=None,
-    drpout_p=0.0,
-    is_causal=False,
-    scale=None,
-):
-    _validate_params(
-        query,
-        key,
-        value,
-        key_cache,
-        value_cache,
-        start_pos,
-        seq_len,
-        attn_mask,
-        drpout_p,
-        is_causal,
-        scale,
-    )
-
-    if attn_mask is not None:
-        attn_mask = attn_mask[start_pos].view((1, -1))
-        attn_mask = attn_mask[:, : start_pos + seq_len]
-    q = query.transpose(1, 2)
-    key_cache[:, start_pos] = key
-    value_cache[:, start_pos] = value
-
-    sliced_k_cache = key_cache
-    sliced_v_cache = value_cache
-    sliced_k_cache = sliced_k_cache[:, : start_pos + seq_len, :, :]
-    sliced_v_cache = sliced_v_cache[:, : start_pos + seq_len, :, :]
-    sliced_k_cache = sliced_k_cache.transpose(1, 2)
-    sliced_v_cache = sliced_v_cache.transpose(1, 2)
-    out = torch.nn.functional.scaled_dot_product_attention(
-        q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
-    )
-    out = out.transpose(1, 2)
-    return out
-
-
-@impl_abstract("llama::sdpa_with_kv_cache.out")
-def sdpa_with_kv_cache_out(
-    query,
-    key,
-    value,
-    key_cache,
-    value_cache,
-    start_pos,
-    seq_len,
-    attn_mask,
-    drpout_p,
-    is_causal,
-    scale,
-    out,
-):
-    out = sdpa_with_kv_cache_meta(
-        query,
-        key,
-        value,
-        key_cache,
-        value_cache,
-        start_pos,
-        seq_len,
-        attn_mask,
-        drpout_p,
-        is_causal,
-        scale,
-    )
-    return out
diff --git a/examples/models/llama2/custom_ops/targets.bzl b/examples/models/llama2/custom_ops/targets.bzl
index 66ce6e0c04a..cac83abe07d 100644
--- a/examples/models/llama2/custom_ops/targets.bzl
+++ b/examples/models/llama2/custom_ops/targets.bzl
@@ -6,20 +6,6 @@ def define_common_targets():
     The directory containing this targets.bzl file should also contain both
     TARGETS and BUCK files that call this function.
""" - runtime.python_library( - name = "llama_custom_ops_aot_lib", - srcs = [ - "sdpa_with_kv_cache.py", - ], - visibility = [ - "//executorch/...", - "@EXECUTORCH_CLIENTS", - ], - deps = [ - "//caffe2:torch", - ], - ) - runtime.cxx_library( name = "custom_ops", srcs = ["op_sdpa.cpp"], @@ -44,6 +30,35 @@ def define_common_targets(): force_static = True, ) + runtime.cxx_library( + name = "custom_ops_aot_lib", + srcs = [ + "op_sdpa_aot.cpp", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + external_deps = [ + "libtorch", + ], + deps = [ + ":custom_ops", + "//executorch/extension/aten_util:aten_bridge", + ], + ) + + runtime.python_library( + name = "custom_ops_aot_py", + srcs = [ + "sdpa_with_kv_cache.py", + ], + visibility = ["//executorch/..."], + deps = [ + "//caffe2:torch", + ], + ) + runtime.cxx_test( name = "op_sdpa_test", srcs = [ diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 2a259af59cb..e9650f81814 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -277,7 +277,7 @@ def forward( y = self.wo(y) return y else: - from .custom_ops.sdpa_with_kv_cache import sdpa_with_kv_cache # noqa + from .custom_ops import sdpa_with_kv_cache # noqa output = torch.ops.llama.sdpa_with_kv_cache( q, diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt index 98acdf88ca9..cd34eb78e39 100644 --- a/kernels/optimized/CMakeLists.txt +++ b/kernels/optimized/CMakeLists.txt @@ -47,7 +47,7 @@ endif() # Build cpublas. list(TRANSFORM _optimized_cpublas__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(cpublas STATIC ${_optimized_cpublas__srcs}) -target_link_libraries(cpublas PRIVATE executorch eigen_blas) +target_link_libraries(cpublas PRIVATE executorch_no_prim_ops eigen_blas) target_compile_options(cpublas PUBLIC ${_common_compile_options}) # Generate C++ bindings to register kernels into both PyTorch (for AOT) and @@ -61,7 +61,7 @@ message("Generated files ${gen_command_sources}") list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(optimized_kernels ${_optimized_kernels__srcs}) -target_link_libraries(optimized_kernels PRIVATE executorch cpublas) +target_link_libraries(optimized_kernels PRIVATE executorch_no_prim_ops cpublas) target_compile_options(optimized_kernels PUBLIC ${_common_compile_options}) # Build a library for _optimized_kernels_srcs # diff --git a/setup.py b/setup.py index bef57764b9d..26824826d96 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,11 @@ def pybindings(cls) -> bool: def xnnpack(cls) -> bool: return cls._is_env_enabled("EXECUTORCH_BUILD_XNNPACK", default=False) + @classmethod + @property + def llama_custom_ops(cls) -> bool: + return cls._is_env_enabled("EXECUTORCH_BUILD_CUSTOM_OPS_AOT", default=True) + class _BaseExtension(Extension): """A base class that maps an abstract source to an abstract destination.""" @@ -380,6 +385,11 @@ def run(self): # into the portable_lib target. # TODO(dbort): Add MPS/CoreML backends when building on macos. + if ShouldBuild.llama_custom_ops: + cmake_args += [ + "-DEXECUTORCH_BUILD_CUSTOM_OPS_AOT=ON", + ] + build_args += ["--target", "custom_ops_aot_lib"] # Allow adding extra cmake args through the environment. Used by some # tests and demos to expand the set of targets included in the pip # package. 
@@ -437,6 +447,14 @@ def get_ext_modules() -> list[Extension]: "portable_lib.*", "executorch.extension.pybindings.portable_lib" ) ) + if ShouldBuild.llama_custom_ops: + ext_modules.append( + # Install the prebuilt library for custom ops used in llama. + BuiltFile( + "examples/models/llama2/custom_ops/libcustom_ops_aot_lib.*", + "executorch/examples/models/llama2/custom_ops", + ) + ) # Note that setuptools uses the presence of ext_modules as the main signal # that a wheel is platform-specific. If we install any platform-specific
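Taken together with the `sdpa_with_kv_cache.py` changes above, the installed wheel becomes self-contained: the `BuiltFile` entry ships the shared library next to the Python module, and importing the module loads it on demand. A minimal sketch of the expected behavior after installing with the flag enabled (the assertion is illustrative):

```python
import torch

# Importing the module runs the try/except in sdpa_with_kv_cache.py: it
# either finds llama::sdpa_with_kv_cache already registered or loads the
# bundled libcustom_ops_aot_lib.* that BuiltFile placed next to the module.
from executorch.examples.models.llama2.custom_ops import (  # noqa: F401
    sdpa_with_kv_cache,
)

assert torch.ops.llama.sdpa_with_kv_cache.default is not None
```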