9 changes: 9 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -519,6 +519,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
+        use_int32_token=True if args.qnn else False,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
         generate_full_logits=args.generate_full_logits,
@@ -746,6 +747,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
 
 def _load_llama_model_metadata(
     weight_type: WeightType,
+    use_int32_token: bool,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
@@ -759,6 +761,7 @@ def _load_llama_model_metadata(
         "get_max_seq_len": model_args.max_seq_len,
         "get_n_layers": model_args.n_layers,
         "get_vocab_size": model_args.vocab_size,
+        "use_int32_token": use_int32_token,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,
@@ -779,6 +782,7 @@ def _load_llama_model(
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
+    use_int32_token: bool = False,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -852,6 +856,10 @@ def _load_llama_model(
     else:
         raise ValueError(f"Unsupported dtype {dtype}")
 
+    if use_int32_token:
+        token = example_inputs[0].to(torch.int32)
+        example_inputs = (token,) + example_inputs[1:]
+
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -870,6 +878,7 @@ def _load_llama_model(
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
+            use_int32_token,
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,
8 changes: 6 additions & 2 deletions examples/models/llama/runner/runner.cpp
@@ -34,6 +34,7 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+static constexpr auto kUseInt32Token = "use_int32_token";
 } // namespace
 
 Runner::Runner(
@@ -51,6 +52,7 @@ Runner::Runner(
           {kMaxSeqLen, 128},
           {kUseKVCache, true},
           {kUseSDPAWithKVCache, false},
+          {kUseInt32Token, true},
       }) {
   ET_LOG(
       Info,
@@ -128,14 +130,16 @@ Error Runner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape));
+      metadata_.at(kEnableDynamicShape),
+      metadata_.at(kUseInt32Token));
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      &stats_,
+      metadata_.at(kUseInt32Token));
 
   return Error::Ok;
 }
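A behavioral note on the defaults above: `metadata_` seeds fallback values that are only replaced when the loaded `.pte` actually exports a matching metadata method, and `kUseInt32Token` defaults to `true`. A model exported before this change carries no `use_int32_token` method, so it would be driven with int32 tokens even though it was traced with int64 inputs. A minimal sketch of that default-then-override pattern (simplified and self-contained; the real runner populates the map from the loaded module):

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>

int main() {
  // Defaults seeded at construction, as in the diff above.
  std::unordered_map<std::string, int64_t> metadata = {
      {"use_kv_cache", 1},
      {"use_sdpa_with_kv_cache", 0},
      {"use_int32_token", 1}, // note: defaults to int32 tokens
  };
  // A model exported with this PR's Python changes overrides the default:
  metadata["use_int32_token"] = 0;
  // A pre-existing .pte exports no such method, so the default (1) survives.
  std::printf("use_int32_token = %lld\n",
              static_cast<long long>(metadata.at("use_int32_token")));
  return 0;
}
```

If backward compatibility with already-exported models matters, a default of `false` here would be the safer choice.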
15 changes: 9 additions & 6 deletions extension/llm/runner/text_prefiller.cpp
@@ -18,8 +18,10 @@ namespace llm {
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
     bool use_kv_cache,
-    bool enable_parallel_prefill)
+    bool enable_parallel_prefill,
+    bool use_int32_token)
     : text_decoder_runner_(text_decoder_runner),
+      use_int32_token_(use_int32_token),
       use_kv_cache_(use_kv_cache),
       enable_parallel_prefill_(enable_parallel_prefill) {}
 
@@ -36,12 +38,13 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
 
   // store the token
   uint64_t cur_token;
+  exec_aten::ScalarType token_type = use_int32_token_
+      ? exec_aten::ScalarType::Int
+      : exec_aten::ScalarType::Long;
   if (enable_parallel_prefill_ || !use_kv_cache_) {
     // initialize tensor wrappers
-    auto tokens = from_blob(
-        prompt_tokens.data(),
-        {1, num_prompt_tokens},
-        exec_aten::ScalarType::Long);
+    auto tokens =
+        from_blob(prompt_tokens.data(), {1, num_prompt_tokens}, token_type);
 
     auto start_pos_tensor =
         from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
@@ -60,7 +63,7 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
-    auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
+    auto tokens = from_blob(&cur_token, {1, 1}, token_type);
 
     auto start_pos_tensor =
         from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
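It is worth spelling out what the `token_type` switch does at runtime: `from_blob` wraps existing memory without copying or converting, and `prompt_tokens`/`cur_token` stay 64-bit on the C++ side, so with `ScalarType::Int` the model reads the buffer's bytes as int32. For the single-token `{1, 1}` wrapper that is the low half of the `uint64_t`, which is correct on little-endian targets for ids below 2^31. For the multi-token `{1, num_prompt_tokens}` wrapper, however, the 4-byte element stride no longer matches the 8-byte storage stride, so the parallel-prefill path reads interleaved low/high halves. A standalone sketch of both cases (plain C++, mimicking the byte view `from_blob` produces over these buffers):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Case 1: shape {1, 1}. The int32 view sees the low four bytes of the
  // uint64_t token, so ids below 2^31 survive on little-endian targets.
  uint64_t cur_token = 128000;
  std::printf("single: %d\n", *reinterpret_cast<int32_t*>(&cur_token));

  // Case 2: shape {1, N} over uint64_t storage. The int32 view advances in
  // 4-byte steps across 8-byte elements, reading half-words, not tokens.
  std::vector<uint64_t> prompt_tokens = {10, 20, 30};
  const int32_t* view = reinterpret_cast<const int32_t*>(prompt_tokens.data());
  for (int i = 0; i < 3; ++i) {
    std::printf("view[%d] = %d, token[%d] = %llu\n", i, view[i], i,
                static_cast<unsigned long long>(prompt_tokens[i]));
  }
  // Prints 10, 0, 20 rather than 10, 20, 30.
  return 0;
}
```

Copying the ids into an int32 staging buffer before wrapping (rather than reinterpreting the uint64_t storage) would make the parallel path safe as well.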
4 changes: 3 additions & 1 deletion extension/llm/runner/text_prefiller.h
@@ -25,7 +25,8 @@ class ET_EXPERIMENTAL TextPrefiller {
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
       bool use_kv_cache_,
-      bool enable_parallel_prefill);
+      bool enable_parallel_prefill,
+      bool use_int32_token = false);
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -40,6 +41,7 @@ class ET_EXPERIMENTAL TextPrefiller {
 
  private:
   TextDecoderRunner* text_decoder_runner_;
+  bool use_int32_token_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
 };
11 changes: 8 additions & 3 deletions extension/llm/runner/text_token_generator.h
@@ -25,10 +25,12 @@ class ET_EXPERIMENTAL TextTokenGenerator {
       TextDecoderRunner* text_decoder_runner,
       bool use_kv_cache,
       std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
-      Stats* stats)
+      Stats* stats,
+      bool use_int32_token = false)
       : tokenizer_(tokenizer),
         text_decoder_runner_(text_decoder_runner),
         eos_ids_(std::move(eos_ids)),
+        use_int32_token_(use_int32_token),
         use_kv_cache_(use_kv_cache),
         stats_(stats) {}
 
@@ -54,6 +56,9 @@ class ET_EXPERIMENTAL TextTokenGenerator {
 
     std::vector<uint64_t> token_data; // allocate space for the tokens
     std::vector<executorch::aten::SizesType> token_shape;
+    exec_aten::ScalarType token_type = use_int32_token_
+        ? exec_aten::ScalarType::Int
+        : exec_aten::ScalarType::Long;
 
     // Token after prefill
     uint64_t cur_token = tokens.back();
@@ -70,8 +75,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     }
 
     // initialize tensor wrappers
-    auto tokens_managed = from_blob(
-        token_data.data(), token_shape, executorch::aten::ScalarType::Long);
+    auto tokens_managed = from_blob(token_data.data(), token_shape, token_type);
     auto start_pos_managed =
         from_blob(&pos, {1}, executorch::aten::ScalarType::Long);
 
@@ -133,6 +137,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   Tokenizer* tokenizer_;
   TextDecoderRunner* text_decoder_runner_;
   std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
+  bool use_int32_token_;
   bool use_kv_cache_;
 
   // state machine
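A possible follow-up, not part of this diff: the Int/Long ternary now appears verbatim in both `TextPrefiller::prefill` and `TextTokenGenerator::generate`, and the two copies can drift apart. A tiny shared helper (hypothetical name and location, using the same `exec_aten` types these files already pull in) would keep them in sync:

```cpp
// Hypothetical helper, e.g. somewhere under extension/llm/runner; not in
// this PR. Centralizes the token-dtype choice made in both call sites.
inline exec_aten::ScalarType token_scalar_type(bool use_int32_token) {
  return use_int32_token ? exec_aten::ScalarType::Int
                         : exec_aten::ScalarType::Long;
}
```

Both call sites would then reduce to `auto token_type = token_scalar_type(use_int32_token_);`.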