Merged
7 changes: 3 additions & 4 deletions csrc/gpu/get_output.cc
@@ -20,8 +20,7 @@

#include "paddle/extension.h"

#define MAX_BSZ 512
#define SPECULATE_MAX_BSZ 256
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

template <int SIZE>
@@ -70,8 +69,8 @@ void GetOutput(const paddle::Tensor& x,
static struct MsgData<SIZE> msg_rcv;
GetOutputFunc<SIZE>(msg_rcv, x, rank_id, wait_flag);
} else {
constexpr int SIZE = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS +
SPECULATE_MAX_BSZ +
constexpr int SIZE = MAX_BSZ * MAX_DRAFT_TOKENS +
MAX_BSZ +
2; // stop_flag, bsz, accept_num*bsz, tokens...
static struct MsgData<SIZE> specu_msg_rcv;
GetOutputFunc<SIZE>(specu_msg_rcv, x, rank_id, wait_flag);
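For reference, a minimal sketch (not part of the PR) of how the unified MAX_BSZ constant now sizes the speculative-decoding message buffer; the slot names follow the inline comment above (stop flag, batch size, one accept count per batch slot, then the draft tokens):

# Sketch only: constants copied from the new defines; the layout is inferred
# from the "stop_flag, bsz, accept_num*bsz, tokens..." comment.
MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6

SPECULATE_MSG_SIZE = MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2
assert SPECULATE_MSG_SIZE == 1794  # 2 header slots + 256 accept counts + 1536 token slots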
13 changes: 6 additions & 7 deletions csrc/gpu/save_with_output_msg.cc
@@ -20,8 +20,7 @@

#include "paddle/extension.h"

#define MAX_BSZ 512
#define SPECULATE_MAX_BSZ 256
#define MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

template <int SIZE>
@@ -63,15 +62,15 @@ void SaveOutMsgFunc(MsgData<SIZE>& msg_sed, // NOLINT
msg_sed.mtype = 1;
msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
msg_sed.mtext[1] = bsz;
for (int i = 2; i < SPECULATE_MAX_BSZ + 2; i++) {
for (int i = 2; i < MAX_BSZ + 2; i++) {
if (i - 2 >= bsz) {
msg_sed.mtext[i] = 0;
} else {
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
}
}
for (int i = SPECULATE_MAX_BSZ + 2; i < SIZE; i++) {
int token_id = i - SPECULATE_MAX_BSZ - 2;
for (int i = MAX_BSZ + 2; i < SIZE; i++) {
int token_id = i - MAX_BSZ - 2;
int bid = token_id / MAX_DRAFT_TOKENS;
int local_token_id = token_id % MAX_DRAFT_TOKENS;
if (token_id / MAX_DRAFT_TOKENS >= bsz) {
@@ -97,8 +96,8 @@ void SaveOutMsg(const paddle::Tensor& x,
static struct MsgData<SIZE> msg_sed;
SaveOutMsgFunc<SIZE>(msg_sed, x, not_need_stop, accept_num, rank_id);
} else {
constexpr int SIZE = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS +
SPECULATE_MAX_BSZ +
constexpr int SIZE = MAX_BSZ * MAX_DRAFT_TOKENS +
MAX_BSZ +
2; // stop_flag, bsz, accept_num*bsz, tokens...
static struct MsgData<SIZE> specu_msg_sed;
SaveOutMsgFunc<SIZE>(specu_msg_sed, x, not_need_stop, accept_num, rank_id);
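The loops above write a flat message: slot 0 is the stop flag, slot 1 the batch size, the next MAX_BSZ slots hold accept_num per batch, and the remaining slots hold the accepted draft tokens row by row. A hedged plain-Python transcription of that packing (pack_speculate_msg, accept_nums and accept_tokens are hypothetical stand-ins for the tensors the op reads, not code from the PR):

def pack_speculate_msg(not_need_stop, accept_nums, accept_tokens,
                       max_bsz=256, max_draft_tokens=6):
    # Mirrors the SaveOutMsgFunc layout: [stop_flag, bsz, accept_num * MAX_BSZ, tokens...].
    bsz = len(accept_nums)
    size = max_bsz * max_draft_tokens + max_bsz + 2
    msg = [0] * size
    msg[0] = 1 if not_need_stop else -1  # stop flag
    msg[1] = bsz                         # batch size
    for i in range(2, max_bsz + 2):      # accept count per batch slot, zero-padded
        msg[i] = accept_nums[i - 2] if i - 2 < bsz else 0
    for i in range(max_bsz + 2, size):   # accepted draft tokens, row-major per batch
        token_id = i - max_bsz - 2
        bid, local_token_id = divmod(token_id, max_draft_tokens)
        if bid < bsz:
            msg[i] = accept_tokens[bid][local_token_id]
    return msg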
6 changes: 3 additions & 3 deletions llm/predict/predictor.py
@@ -45,7 +45,7 @@
PretrainedTokenizer,
)
from paddlenlp.trl import llm_utils
from paddlenlp.utils.env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
from paddlenlp.utils.env import MAX_BSZ, MAX_DRAFT_TOKENS
from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
from paddlenlp.utils.log import logger

@@ -1039,7 +1039,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
output_tensor_shape = [MAX_BSZ + 2, 1]
else:
read_res_func = llm_utils.speculate_read_res
output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
output_tensor_shape = [MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1]

read_res_process = mp.Process(
target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
@@ -1186,7 +1186,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
output_tensor_shape = [MAX_BSZ + 2, 1]
else:
read_res_func = llm_utils.speculate_read_res
output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
output_tensor_shape = [MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1]

read_res_process = mp.Process(
target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
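The speculative branch now allocates a flat output tensor of shape [MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2, 1]. The real reader is llm_utils.speculate_read_res; the helper below is only an illustrative sketch of how such a buffer could be sliced back into per-request results, assuming the packing order shown in save_with_output_msg.cc:

import numpy as np

MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6

def unpack_speculate_output(flat):
    # flat: buffer of length MAX_BSZ * MAX_DRAFT_TOKENS + MAX_BSZ + 2.
    flat = np.asarray(flat).reshape(-1)
    not_need_stop = flat[0] > 0
    bsz = int(flat[1])
    accept_num = flat[2 : 2 + bsz].astype(int)
    tokens = flat[2 + MAX_BSZ :].reshape(MAX_BSZ, MAX_DRAFT_TOKENS)
    # Keep only the accepted tokens for each request in the batch.
    accepted = [tokens[b, : accept_num[b]].tolist() for b in range(bsz)]
    return not_need_stop, accepted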
189 changes: 141 additions & 48 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -555,6 +555,24 @@
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id

def get_output_padding_offset(self, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder):

"""
In the scenario of speculative decoding, the number of output tokens after rebuild_padding is no longer bsz.
So we need to calculate the output_padding_offset after rebuild_padding.
"""
from paddlenlp_ops import (

speculate_get_output_padding_offset,
speculate_get_seq_lens_output,
)

seq_lens_output = speculate_get_seq_lens_output(seq_lens_this_time, seq_lens_encoder, seq_lens_decoder)
out_token_num = paddle.sum(seq_lens_output)
output_cum_offsets_tmp = paddle.cumsum(self.max_seq_len - seq_lens_output)
output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(

output_cum_offsets_tmp, out_token_num, seq_lens_output, self.max_seq_len
)
return output_padding_offset, output_cum_offsets

@paddle.no_grad()
def generate(
self,
@@ -665,66 +683,141 @@
):
step_idx = model_kwargs["step_idx"]
logits = paddle.cast(outputs, paddle.float32)
from paddlenlp_ops import set_preids_token_penalty_multi_scores

set_preids_token_penalty_multi_scores(
model_kwargs["pre_ids"],
model_kwargs["input_ids"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
step_idx,
model_kwargs["stop_flags"],
logits,
penalty_score,
frequency_score,
presence_score,
temperature,
model_kwargs["bad_tokens"],
step_idx,
model_kwargs["min_dec_len"],
eos_token_id,
)
# TODO(Wanglongzhi2001): token_penalty of speculative decoding
if not is_speculative_decoding:
from paddlenlp_ops import set_preids_token_penalty_multi_scores

set_preids_token_penalty_multi_scores(

model_kwargs["pre_ids"],
model_kwargs["input_ids"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
step_idx,
model_kwargs["stop_flags"],
logits,
penalty_score,
frequency_score,
presence_score,
temperature,
model_kwargs["bad_tokens"],
step_idx,
model_kwargs["min_dec_len"],
eos_token_id,
)

# sample
probs = F.softmax(logits)

# compute next_tokens
if use_faster_top_p_sampling():
from paddlenlp_ops import top_p_sampling_reject
from paddlenlp_ops import save_output

next_tokens = top_p_sampling_reject(probs, top_p, 0)
# whether speculative decoding
if not is_speculative_decoding:

# compute next_tokens
if use_faster_top_p_sampling():
from paddlenlp_ops import top_p_sampling_reject

next_tokens = top_p_sampling_reject(probs, top_p, 0)

else:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)

if self.config.tensor_parallel_degree > 1:
paddle.distributed.broadcast(next_tokens, 0)

from paddlenlp_ops import update_inputs_v2

update_inputs_v2(

model_kwargs["stop_flags"],
model_kwargs["step_idx"],
model_kwargs["not_need_stop"],
model_kwargs["seq_lens_this_time"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["max_dec_len"],
model_kwargs["input_ids"],
model_kwargs["stop_nums"],
next_tokens,
model_kwargs["is_block_step"],
eos_token_id,
model_kwargs["next_tokens"],
)

save_output(

next_tokens,
model_kwargs["not_need_stop"],
model_kwargs.get("accept_num", None), # only initialized in speculative decoding
self.config.tensor_parallel_rank,
)
return next_tokens

else:
_, next_tokens = paddle.tensor.top_p_sampling(probs, top_p)
from paddlenlp_ops import (

speculate_set_value_by_flags_and_idx,
speculate_verify_and_update,
top_p_candidates,
)

if self.config.tensor_parallel_degree > 1:
paddle.distributed.broadcast(next_tokens, 0)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(

probs, top_p, model_kwargs["output_padding_offset"], self.max_candidate_len, self.max_seq_len
) # [token_num, max_candidate_len]

from paddlenlp_ops import update_inputs_v2
# Speculate Verify And Update
speculate_verify_and_update(

model_kwargs["accept_tokens"],
model_kwargs["accept_num"],
model_kwargs["step_idx"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["stop_flags"],
model_kwargs["not_need_stop"],
model_kwargs[
"draft_tokens"
], # Both input and output, need to write the last 1 token accepted to position 0.
model_kwargs["seq_lens_this_time"],
verify_tokens,
verify_scores,
model_kwargs["max_dec_len"],
eos_token_id,
model_kwargs["is_block_step"],
model_kwargs["output_cum_offsets"],
actual_candidate_len,
model_kwargs["actual_draft_token_num"],
top_p,
self.max_seq_len,
self.verify_window,
True, # enable_topp
)

update_inputs_v2(
model_kwargs["stop_flags"],
model_kwargs["step_idx"],
model_kwargs["not_need_stop"],
model_kwargs["seq_lens_this_time"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["max_dec_len"],
model_kwargs["input_ids"],
model_kwargs["stop_nums"],
next_tokens,
model_kwargs["is_block_step"],
eos_token_id,
model_kwargs["next_tokens"],
)
from paddlenlp_ops import save_output
save_output(

model_kwargs["accept_tokens"],
model_kwargs["not_need_stop"],
model_kwargs["accept_num"],
self.config.tensor_parallel_rank,
)

save_output(
next_tokens,
model_kwargs["not_need_stop"],
model_kwargs.get("accept_tokens", None), # only initialized in speculative decoding
self.config.tensor_parallel_rank,
# If seq_lens_decoder is 0 (means stop), accept_num should be set to 0
model_kwargs["accept_num"][model_kwargs["seq_lens_decoder"] == 0] = 0

# Update pre_ids through accept tokens
speculate_set_value_by_flags_and_idx(

model_kwargs["pre_ids"],
model_kwargs["accept_tokens"],
model_kwargs["accept_num"],
model_kwargs["stop_flags"],
model_kwargs["seq_lens_this_time"],
model_kwargs["seq_lens_encoder"],
model_kwargs["seq_lens_decoder"],
model_kwargs["step_idx"],
)

is_speculative_decoding = model_kwargs.get("draft_tokens", None) is not None
if is_speculative_decoding:

# Prepare output padding offset
output_padding_offset, output_cum_offsets = self.get_output_padding_offset(

model_kwargs["seq_lens_this_time"], model_kwargs["seq_lens_encoder"], model_kwargs["seq_lens_decoder"]
)
return next_tokens
model_kwargs["output_padding_offset"] = output_padding_offset
model_kwargs["output_cum_offsets"] = output_cum_offsets

# encoder
outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed]
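A note on the new get_output_padding_offset helper: with speculative decoding each request can emit several tokens per step, so the packed output after rebuild_padding no longer has one token per batch item, and downstream ops need per-token offsets back into the padded layout. The custom speculate_get_seq_lens_output / speculate_get_output_padding_offset kernels do this on device; the NumPy sketch below is only one plausible reading of that bookkeeping (assuming seq_lens_output holds the per-request number of output tokens), not the kernels' actual implementation:

import numpy as np

def get_output_padding_offset_ref(seq_lens_output, max_seq_len):
    # seq_lens_output: per-request number of output tokens after rebuild_padding (assumption).
    seq_lens_output = np.asarray(seq_lens_output, dtype=np.int64)
    # Padding accumulated in front of each request's slot in the padded layout.
    cum_pad = np.cumsum(max_seq_len - seq_lens_output)
    output_cum_offsets = np.concatenate(([0], cum_pad[:-1]))
    # For every packed token, how far it is shifted relative to its padded position.
    output_padding_offset = np.repeat(output_cum_offsets, seq_lens_output)
    return output_padding_offset, output_cum_offsets

For example, with seq_lens_output = [3, 1, 2] and max_seq_len = 8 this yields output_cum_offsets = [0, 5, 12] and output_padding_offset = [0, 0, 0, 5, 12, 12].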