diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f44b650aa1..42ff6f06e66 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1124,6 +1124,8 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING) endif() if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/memory_allocator) + list(APPEND _executorch_extensions extension_memory_allocator) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner) list(APPEND _executorch_extensions extension_llm_runner) endif() diff --git a/extension/llm/runner/CMakeLists.txt b/extension/llm/runner/CMakeLists.txt index 655c2610ade..43b89f0a908 100644 --- a/extension/llm/runner/CMakeLists.txt +++ b/extension/llm/runner/CMakeLists.txt @@ -39,8 +39,9 @@ add_subdirectory( ${CMAKE_CURRENT_BINARY_DIR}/../sampler ) -set(runner_deps executorch_core extension_module extension_tensor - extension_llm_sampler tokenizers::tokenizers +set(runner_deps + executorch_core extension_module extension_tensor extension_llm_sampler + extension_memory_allocator tokenizers::tokenizers ) # depend on arange_utils diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index bf6de1cee68..4d34fd716e3 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -226,12 +227,28 @@ std::unique_ptr create_text_llm_runner( // Create the Module std::unique_ptr module; + uint32_t max_cached_memory_size_bytes_ = 1024 * 1024 * 10; // 10MB if (data_files.size() > 0) { module = std::make_unique( - model_path, data_files, load_mode, std::move(event_tracer)); + model_path, + data_files, + load_mode, + std::move(event_tracer), + nullptr, // memory allocator + std::make_unique< + executorch::extension::CPUCachingAllocator>( // temp memory + // allocator + max_cached_memory_size_bytes_)); } else { module = std::make_unique( - model_path, load_mode, std::move(event_tracer)); + model_path, + load_mode, + std::move(event_tracer), // event tracer + nullptr, // memory allocator + std::make_unique< + executorch::extension::CPUCachingAllocator>( // temp memory + // allocator + max_cached_memory_size_bytes_)); } // Get metadata from Module diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 2c9000d0137..0d4ed99308d 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -132,6 +132,7 @@ def define_common_targets(): ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, + "//executorch/extension/memory_allocator:cpu_caching_allocator", "//pytorch/tokenizers:hf_tokenizer", "//pytorch/tokenizers:llama2c_tokenizer", "//pytorch/tokenizers:sentencepiece",