From 20d9cacf8440ca149dcd55c97ea860d8718076da Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 5 Apr 2024 13:44:08 -0400 Subject: [PATCH] [KVCache] Initialize one extra page than specified This PR udpates PagedKVCache to initialize one more page than specified via constructor. The reason is that applications usually depends the number of free pages (returned from `GetNumAvailablePages`) to decide the KV cache operation policy. If there is no this extra page, the KV cache will tell "no available" pages even when the last allocated pages are not full, which may give the applications an illusion that the KV cache is already completely full, and cause further issues. --- src/runtime/relax_vm/paged_kv_cache.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index e16d79885e67..0c635967f25d 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1790,7 +1790,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") int64_t prefill_chunk_size = cache_config[2]; int64_t page_size = cache_config[3]; bool support_sliding_window = cache_config[4]; - int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size; + int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size + 1; if (support_sliding_window) { // When sliding window is enabled, each sequence may use two more pages at most. num_total_pages += reserved_num_seqs * 2; @@ -1827,7 +1827,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced") int64_t prefill_chunk_size = cache_config[2]; int64_t page_size = cache_config[3]; bool support_sliding_window = cache_config[4]; - int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size; + int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size + 1; if (support_sliding_window) { // When sliding window is enabled, each sequence may use two more pages at most. num_total_pages += reserved_num_seqs * 2;