diff --git a/src/runtime/relax_vm/kv_cache.h b/src/runtime/relax_vm/kv_cache.h index 4f4b538cb3b4..b201ab93f659 100644 --- a/src/runtime/relax_vm/kv_cache.h +++ b/src/runtime/relax_vm/kv_cache.h @@ -150,6 +150,15 @@ class AttentionKVCache : public Object { virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional mask, NDArray o_data) = 0; + /************** Positions **************/ + + /*! + * \brief Get the in-sequence positions of each slot in the query. + * This function is supposed to be invoked after calling BeginForward. + * \return The in-sequence query positions, in shape `(total_length,)`. + */ + virtual NDArray GetQueryPositions() const = 0; + /************** Debug Helpers **************/ /*! diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index a8c38ca4ed3d..7417d90e02ad 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -744,6 +744,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCache { AttentionInternal(layer_id, q_data, k_data, v_data, o_data); } + NDArray GetQueryPositions() const final { + CHECK(!dirty_aux_data_device_) + << "The auxiliary arrays are not synchronized to device. Please call " + "`BeginForward` to synchronize before calling `GetQueryPositions`."; + return q_rope_position_map_view_; + }; + void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, NDArray k_data, NDArray v_data) final { CHECK(f_debug_get_kv_.defined()) @@ -1231,6 +1238,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_begin_forward") .set_body_method(&PagedAttentionKVCacheObj::BeginForward); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_end_forward") .set_body_method(&PagedAttentionKVCacheObj::EndForward); +TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions") + .set_body_method(&PagedAttentionKVCacheObj::GetQueryPositions); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv") .set_body_method(&PagedAttentionKVCacheObj::DebugGetKV); TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")