Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-28 18:51:58 +08:00
Commit: remove debug code
@@ -80,30 +80,10 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
         &src_vec);
     Store<T, VecSize>(src_vec, &output_data[i]);
-
-    // printf(
-    //     "[normal] out_token_id: %d, ori_token_id: %d, input_token_id: %d "
-    //     "bias_idx: %d, bid: %d, seq_id: %d\n",
-    //     out_token_id,
-    //     ori_token_id,
-    //     input_token_id,
-    //     bias_idx,
-    //     bi,
-    //     seq_id);

     if (enable_logprob && seq_len_encoder[bi] > 0) {
       int first_token_seq_id = seq_len_encoder[bi] - 2;
       const int first_token_id =
           ori_token_id - cum_offset_bi + first_token_seq_id;
-      // printf(
-      //     "[first token] out_token_id: %d, ori_token_id: %d, "
-      //     "first_token_id: %d, bias_idx: %d, bid: %d, "
-      //     "first_token_seq_id: %d\n",
-      //     out_token_id,
-      //     ori_token_id,
-      //     first_token_id,
-      //     bias_idx,
-      //     bi,
-      //     first_token_seq_id);
       Load<T, VecSize>(&input_data[first_token_id * dim_embed + bias_idx],
                        &src_vec);
       Store<T, VecSize>(src_vec, &first_token_out[i]);
@@ -153,9 +133,6 @@ std::vector<paddle::Tensor> rebuild_padding(
         0,
         D,
         tmp_out.place());
-    // printf("token_num: %d, need_delete_token_num: %d\n",
-    //        token_num,
-    //        need_delete_token_num);
   } else {
     out =
         paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
@@ -169,10 +146,6 @@ std::vector<paddle::Tensor> rebuild_padding(
   printf("elem_nums: %d\n", elem_nums);

   if (output_padding_offset) {
-    // if (first_token_out.is_initialized()) {
-    //   printf("first_token_out is initialized, enable_logprob: %d\n",
-    //          enable_logprob);
-    // }
     RebuildAppendPaddingKernel<DataType_, PackSize>
         <<<grid_size, blocksize, 0, cu_stream>>>(
             reinterpret_cast<DataType_ *>(out.data<data_t>()),
@@ -26,7 +26,7 @@
 #define MAX_BSZ 512
 #define K 20
 #define MAX_DRAFT_TOKEN_NUM 6
-#define SPECULATE_GET_WITH_OUTPUT_DEBUG
+// #define SPECULATE_GET_WITH_OUTPUT_DEBUG

 struct batch_msgdata {
   int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
@@ -26,7 +26,7 @@
 #define MAX_BSZ 512
 #define K 20
 #define MAX_DRAFT_TOKEN_NUM 6
-#define SPECULATE_SAVE_WITH_OUTPUT_DEBUG
+// #define SPECULATE_SAVE_WITH_OUTPUT_DEBUG

 struct batch_msgdata {
   int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)];
@@ -134,7 +134,6 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
     for (int j = 0; j < cur_token_num; j++) {
       auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)];
       auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)];
-      std::cout << "token_offset: " << token_offset << std::endl;
       for (int k = 0; k < K + 1; k++) {
         if (k == 0) {
           cur_tokens[k] =
@@ -403,8 +403,6 @@ class EngineArgs:
         if self.dynamic_load_weight:
             self.enable_prefix_caching = False
         if self.enable_logprob:
-            # if self.speculative_config is not None:
-            #     raise NotImplementedError("Logprob does not support speculation_config.")
             if not current_platform.is_cuda():
                 raise NotImplementedError("Only CUDA platform supports logprob.")
         if self.splitwise_role != "mixed":
@@ -290,11 +290,8 @@ class Sampler(nn.Layer):
         # Get with the logprob of the prompt or sampled token.
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

-        print(f"[Sampler] logprobs: {logprobs}")
-        print(f"[Sampler] token_logprobs: {token_logprobs}")
         # Compute the ranks of the actual token.
         token_ranks = (logprobs >= token_logprobs).sum(-1)
-        print(f"[Sampler] token_ranks: {token_ranks}")

         if num_logprobs >= 1:
             # Find the topK values.
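
For reference, the rank logic kept above can be reproduced standalone. This is a minimal sketch with made-up probabilities (not part of the commit), using the same take_along_axis / comparison pattern:

import paddle

# toy case: 2 sequences over a 5-token vocabulary
logprobs = paddle.log(paddle.to_tensor([[0.1, 0.2, 0.4, 0.2, 0.1],
                                        [0.7, 0.1, 0.1, 0.05, 0.05]]))
token_ids = paddle.to_tensor([[2], [1]])  # sampled token per sequence

# logprob of the sampled token, shape [2, 1]
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# rank = how many vocabulary entries score at least as high as the sampled token
token_ranks = (logprobs >= token_logprobs).sum(-1)  # -> [1, 3]
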
@@ -363,7 +360,6 @@ class Sampler(nn.Layer):
             sampled_token_ids=next_tokens,
             logprobs_tensors=logprobs_tensors,
         )
-        print(f"[Sampler] sampler_output: {sampler_output}")

         return sampler_output
@@ -407,11 +403,8 @@ class SpeculativeSampler(nn.Layer):
         share_inputs = sampling_metadata.share_inputs
         last_logits = logits
         real_bsz = share_inputs["seq_lens_this_time"].shape[0]
-        print(f"[SpeculativeSampler][compute] seq_lens_this_time: {share_inputs['seq_lens_this_time']}")
-        print(f"[SpeculativeSampler][compute] seq_lens_encoder: {share_inputs['seq_lens_encoder']}")
         batch_token_num = share_inputs["batch_token_num"]

-        print(f"[SpeculativeSampler][compute] batch_token_num: {batch_token_num}")
         temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
         top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
         if temp_scaled_logprobs is not None:
@@ -479,11 +472,8 @@ class SpeculativeSampler(nn.Layer):
         # Get with the logprob of the prompt or sampled token.
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

-        print(f"[SpeculativeSampler] logprobs: {logprobs}")
-        print(f"[SpeculativeSampler] token_logprobs: {token_logprobs}")
         # Compute the ranks of the actual token.
         token_ranks = (logprobs >= token_logprobs).sum(-1)
-        print(f"[SpeculativeSampler] token_ranks: {token_ranks}")

         if num_logprobs >= 1:
             # Find the topK values.
@@ -534,9 +524,6 @@ class SpeculativeSampler(nn.Layer):
             max_model_len,
         )

-        print(f"[SpeculativeSampler] verify_tokens: {verify_tokens}")
-        print(f"[SpeculativeSampler] actual_candidate_len: {actual_candidate_len}")

         speculate_verify(
             share_inputs["accept_tokens"],
             share_inputs["accept_num"],
@@ -562,10 +549,7 @@ class SpeculativeSampler(nn.Layer):
             True, # enable_topp
             self.speculative_benchmark_mode,
         )
-        print(f"[SpeculativeSampler] accept_num: {share_inputs['accept_num']}")
-        print(f"[SpeculativeSampler] accept_tokens: {share_inputs['accept_tokens']}")

-        print(f"[SpeculativeSampler] logits: {logits}")
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
             real_bsz = share_inputs["seq_lens_this_time"].shape[0]
@@ -575,15 +559,13 @@ class SpeculativeSampler(nn.Layer):
                 share_inputs["accept_num"][:real_bsz].unsqueeze(1),
             ).squeeze(1)
             share_inputs["batch_token_num"] = batch_token_num
-            print(f"[SpeculativeSampler] batch_token_num: {share_inputs['batch_token_num']}")
             ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
                 "int32"
             )
             cu_batch_token_offset = paddle.concat(
                 [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])]
             ).astype("int32")
-            print(f"[SpeculativeSampler] ori_cu_batch_token_offset: {ori_cu_batch_token_offset}")
-            print(f"[SpeculativeSampler] cu_batch_token_offset: {cu_batch_token_offset}")
             share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
             target_logtis = paddle.empty([share_inputs["accept_num"].sum(), logits.shape[1]], dtype=logits.dtype)
             speculate_get_target_logits(
                 target_logtis,
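
As an aside, the two offset tensors built above are plain exclusive prefix sums over per-request token counts; a minimal standalone sketch with made-up accept counts:

import paddle

accept_num = paddle.to_tensor([3, 1, 2])  # hypothetical accepted tokens per request
cu_batch_token_offset = paddle.concat(
    [paddle.to_tensor([0]), paddle.cumsum(accept_num)]
).astype("int32")
# -> [0, 3, 4, 6]; request i owns rows [offset[i], offset[i+1]) of the flattened token dim
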
@@ -594,9 +576,7 @@ class SpeculativeSampler(nn.Layer):
                 share_inputs["seq_lens_encoder"],
                 share_inputs["accept_num"],
             )
-            print(f"[SpeculativeSampler] target_logtis: {target_logtis}")
             raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
-            print(f"[SpeculativeSampler] raw_logprobs: {raw_logprobs}")

         sampler_output = None
         if num_logprobs is not None:
@@ -608,7 +588,6 @@ class SpeculativeSampler(nn.Layer):
                     for i in range(share_inputs["accept_num"].shape[0])
                 ]
             )
-            print(f"[SpeculativeSampler] token_ids: {token_ids}")
             logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

             sampler_output = SamplerOutput(
@@ -617,8 +596,6 @@ class SpeculativeSampler(nn.Layer):
                 token_num_per_batch=batch_token_num,
             )

-        print(f"[SpeculativeSampler] sampler_output: {sampler_output}")

         return sampler_output

@@ -656,7 +633,6 @@ class MTPSampler(nn.Layer):
         share_inputs = sampling_metadata.share_inputs
         real_bsz = share_inputs["seq_lens_this_time"].shape[0]
         last_logits = logits
-        # print(f"[MTPSampler][compute] real_bsz: {real_bsz}")
         temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
         top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
         if temp_scaled_logprobs is not None:
@@ -669,17 +645,12 @@ class MTPSampler(nn.Layer):
                 .astype("bool")
             )
             temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"])
-            # print(f"[MTPSampler][compute] real_bsz_temp_scaled: {real_bsz_temp_scaled}")
-            # print(f"[MTPSampler][compute] temperature: {temperature}")
             temp_temperature = paddle.where(
                 real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
             ).unsqueeze(1)
-            # print(f"[MTPSampler][compute] temp_temperature: {temp_temperature}")
             last_logits = last_logits / temp_temperature
-            # print(f"[MTPSampler][compute] last_logits: {last_logits}")

         last_logprobs = F.log_softmax(last_logits, axis=-1)
-        # print(f"[MTPSampler][compute] last_logits: {last_logits}")
         top_p_logprob = None
         top_p_token_mask = None

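
The temperature path kept above broadcasts a per-request temperature to every draft token before the log-softmax. A minimal sketch under assumed shapes (2 requests carrying 3 and 1 tokens, 32-token vocabulary; all values made up):

import paddle
import paddle.nn.functional as F

batch_token_num = paddle.to_tensor([3, 1])      # tokens carried by each request
temperature = paddle.to_tensor([[0.8], [1.2]])  # per-request temperature
temp_scaled = paddle.to_tensor([True, False])   # requests that want temp-scaled logprobs

per_token_temp = temperature.squeeze(1).repeat_interleave(batch_token_num)
per_token_mask = temp_scaled.astype("int32").repeat_interleave(batch_token_num).astype("bool")
effective_temp = paddle.where(
    per_token_mask, per_token_temp, paddle.ones_like(per_token_temp)
).unsqueeze(1)                                  # shape [4, 1]

logits = paddle.randn([4, 32])                  # packed [total_tokens, vocab]
logprobs = F.log_softmax(logits / effective_temp, axis=-1)
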
@@ -690,7 +661,6 @@ class MTPSampler(nn.Layer):
                 .repeat_interleave(share_inputs["batch_token_num"])
                 .unsqueeze(1)
             )
-            # print(f"[MTPSampler][compute] real_token_top_p: {real_token_top_p}")
             top_p_normalized_logprobs = (
                 top_p_normalized_logprobs[:real_bsz]
                 .astype("int32")
@@ -699,17 +669,12 @@ class MTPSampler(nn.Layer):
                 .astype("bool")
                 .unsqueeze(1)
             )
-            # print(f"[MTPSampler][compute] top_p_normalized_logprobs: {top_p_normalized_logprobs}")
             top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
-            # print(f"[MTPSampler][compute] top_p_token_mask: {top_p_token_mask}")

         if top_p_token_mask.any():
             probs = F.softmax(last_logits, axis=-1)
-            # print(f"[MTPSampler][compute] probs: {probs}")
             probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
-            # print(f"[MTPSampler][compute] probs: {probs}")
             top_p_logprob = paddle.log(probs)
-            # print(f"[MTPSampler][compute] top_p_logprob: {top_p_logprob}")
         if top_p_logprob is not None:
             last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
         return last_logprobs
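
top_p_normalize_probs_paddle is FastDeploy's own helper, so the sketch below only approximates the idea behind the kept branch: restrict each row to its top-p nucleus, renormalize, and take the log. Everything here is an assumption for illustration (ties at the threshold are kept), not the project's implementation:

import paddle
import paddle.nn.functional as F

def top_p_renormalize(probs, top_p):
    # probs: [num_tokens, vocab]; top_p: scalar or [num_tokens, 1]
    sorted_probs = paddle.sort(probs, axis=-1, descending=True)
    cum = paddle.cumsum(sorted_probs, axis=-1)
    keep_sorted = (cum - sorted_probs) < top_p            # smallest prefix whose mass reaches top_p
    num_keep = keep_sorted.astype("int64").sum(axis=-1, keepdim=True)
    threshold = paddle.take_along_axis(sorted_probs, num_keep - 1, axis=-1)
    kept = paddle.where(probs >= threshold, probs, paddle.zeros_like(probs))
    return kept / kept.sum(axis=-1, keepdim=True)

last_logits = paddle.randn([4, 32])
top_p_logprob = paddle.log(top_p_renormalize(F.softmax(last_logits, axis=-1), 0.9))
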
@@ -767,7 +732,6 @@ class MTPSampler(nn.Layer):
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None and share_inputs["substep"] == 0:
             raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"], sampling_metadata)
-            print(f"[MTPSampler] raw_logprobs: {raw_logprobs}")

         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
@@ -803,8 +767,6 @@ class MTPSampler(nn.Layer):
                 share_inputs["seq_lens_this_time"],
                 share_inputs["seq_lens_encoder"],
             )
-            print(f"[MTPSampler] token_ids: {token_ids}")
-            print(f"[MTPSampler] total_token_num: {share_inputs['batch_token_num'].sum()}")

             logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

@@ -813,7 +775,5 @@ class MTPSampler(nn.Layer):
                 logprobs_tensors=logprobs_tensors,
                 token_num_per_batch=share_inputs["batch_token_num"],
             )
-        print(f"[MTPSampler] sampler_output: {sampler_output}")
-        print(f"[MTPSampler] next_tokens: {next_tokens}")

         return next_tokens, sampler_output

@@ -158,7 +158,6 @@ class TokenProcessor:
             get_output_ep,
             get_output_topk,
             speculate_get_output,
-            speculate_get_output_topk,
         )
         rank_id = self.cfg.parallel_config.local_data_parallel_id

@@ -166,24 +165,9 @@ class TokenProcessor:
         try:
             is_blocking = True
             if self.speculative_decoding:
-                if self.use_logprobs:
-                    speculate_get_output_topk(
-                        self.output_tokens,
-                        self.output_scores,
-                        self.output_ranks,
-                        K,
-                        rank_id,
-                        is_blocking,
-                    )
-                    print(f"[TokenProcessor] output_tokens: {self.output_tokens}")
-                    print(f"[TokenProcessor] output_scores: {self.output_scores}")
-                    print(f"[TokenProcessor] output_ranks: {self.output_ranks}")
-                    if self.output_tokens[0, 0] == -2:
-                        continue
-                else:
-                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
-                    if self.output_tokens[0] == -2:
-                        continue
+                speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                if self.output_tokens[0] == -2:
+                    continue

             else:
                 if self.use_logprobs:
@@ -626,7 +626,6 @@ class MTPProposer(Proposer):
         """
         for substep in range(self.num_model_steps):
             if self.model_inputs["not_need_stop"]:
-                print(f"[MTPProposer] ******************** substep: {substep} ********************")
                 self.model_inputs["substep"] = substep
                 # Remove padding
                 (
@@ -682,19 +681,12 @@ class MTPProposer(Proposer):
                     previous_hidden_states=target_hidden_states,
                     forward_meta=self.forward_meta,
                 )
-                print(f"[MTPProposer] model_output: {model_output}")

                 if self.enable_logprob and substep == 0:
                     first_token_hidden_states = paddle.empty(
                         [self.max_num_seqs, self.model_config.hidden_size], dtype=model_output.dtype
                     )

-                print(f"[MTPProposer] cu_seqlens_q: {self.model_inputs['cu_seqlens_q']}")
-                print(f"[MTPProposer] seq_lens_this_time: {self.model_inputs['seq_lens_this_time']}")
-                print(f"[MTPProposer] seq_lens_encoder: {self.model_inputs['seq_lens_encoder']}")
-                print(f"[MTPProposer] seq_lens_decoder: {self.model_inputs['seq_lens_decoder']}")
-                print(f"[MTPProposer] output_cum_offsets: {self.model_inputs['output_cum_offsets']}")
-                print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}")
                 hidden_states = rebuild_padding(
                     model_output,
                     self.model_inputs["cu_seqlens_q"],
@@ -706,16 +698,11 @@ class MTPProposer(Proposer):
                     first_token_hidden_states if substep == 0 else None,
                     self.enable_logprob if substep == 0 else False,
                 )
-                print(f"[MTPProposer] hidden_states: {hidden_states}")
-                print(f"[MTPProposer] first_token_hidden_states: {first_token_hidden_states}")

                 # 4. Compute logits, Sample
                 logits = self.model.compute_logits(hidden_states)
                 if self.enable_logprob and substep == 0:
                     first_token_logits = self.model.compute_logits(first_token_hidden_states)
-                print(f"[MTPProposer] logits: {logits}")
-                print(f"[MTPProposer] first_token_logits: {first_token_logits}")
-                print(f"[MTPProposer] output_padding_offset: {self.model_inputs['output_padding_offset']}")

                 draft_logits, batch_token_num, cu_batch_token_offset = speculate_get_logits(
                     logits,
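
rebuild_padding, the custom op whose kernel is patched at the top of this diff, collapses the packed [token_num, hidden] activations into one row per request before logits are computed. Conceptually it is a gather at per-request offsets, roughly as in the paddle-level sketch below; the shapes and the last-token choice are assumptions for illustration, not the op's exact semantics:

import paddle

packed = paddle.randn([6, 8])               # all tokens of all requests, concatenated
cu_seqlens_q = paddle.to_tensor([0, 4, 6])  # request i owns rows [cu[i], cu[i+1])
last_token_idx = cu_seqlens_q[1:] - 1       # hidden state that predicts each request's next token
per_request_hidden = paddle.gather(packed, last_token_idx, axis=0)  # shape [2, 8]
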
@@ -727,9 +714,6 @@ class MTPProposer(Proposer):
                 self.model_inputs["draft_logits"] = draft_logits
                 self.model_inputs["batch_token_num"] = batch_token_num
                 self.model_inputs["cu_batch_token_offset"] = cu_batch_token_offset
-                print(f"[MTPProposer] draft_logits: {draft_logits}")
-                print(f"[MTPProposer] batch_token_num: {batch_token_num}")
-                print(f"[MTPProposer] cu_batch_token_offset: {cu_batch_token_offset}")

                 sampled_token_ids, sampler_output = self.sampler(
                     logits,