mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support mtp logprob (#4464)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* support mtp logprob * fix unit test
This commit is contained in:
@@ -32,6 +32,8 @@ from fastdeploy.model_executor.layers.sample.ops import (
|
||||
apply_penalty_multi_scores,
|
||||
apply_speculative_penalty_multi_scores,
|
||||
min_p_sampling,
|
||||
speculate_get_target_logits,
|
||||
speculate_insert_first_token,
|
||||
top_k_top_p_sampling,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -455,6 +457,98 @@ class SpeculativeSampler(nn.Layer):
|
||||
"""apply logits processor to sampler"""
|
||||
pass
|
||||
|
||||
def compute_logprobs(
    self,
    logits: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
    """
    Convert logits into log-probabilities, optionally temperature-scaled
    and top-p renormalized.

    Per-request settings read from ``sampling_metadata`` are expanded to
    one entry per token with ``repeat_interleave`` driven by
    ``share_inputs["batch_token_num"]``.

    Args:
        logits: Logits to convert; log-softmax is taken over the last axis.
            Presumably one row per token -- TODO confirm against caller.
        sampling_metadata: Supplies ``share_inputs``,
            ``temp_scaled_logprobs``, ``top_p_normalized_logprobs``,
            ``temperature`` and ``top_p``.

    Returns:
        Log-probabilities over the last axis of ``logits``. Rows selected
        by the top-p mask hold ``log`` of the top-p normalized
        probabilities instead of the plain log-softmax.
    """
    share_inputs = sampling_metadata.share_inputs
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    tokens_per_request = share_inputs["batch_token_num"][:real_bsz]

    temp_scaled_flags = sampling_metadata.temp_scaled_logprobs
    top_p_flags = sampling_metadata.top_p_normalized_logprobs

    scaled_logits = logits
    if temp_scaled_flags is not None:
        # Flags go through int32 before repeat_interleave and back to bool
        # afterwards (presumably repeat_interleave is not used on bool
        # tensors here -- TODO confirm).
        per_token_flag = temp_scaled_flags[:real_bsz].astype("int32").squeeze(1)
        per_token_flag = per_token_flag.repeat_interleave(tokens_per_request).astype("bool")
        per_token_temperature = sampling_metadata.temperature[:real_bsz].squeeze(1)
        per_token_temperature = per_token_temperature.repeat_interleave(tokens_per_request)
        # Tokens whose flag is off are divided by 1, i.e. left unscaled.
        divisor = paddle.where(
            per_token_flag,
            per_token_temperature,
            paddle.ones_like(per_token_temperature),
        )
        scaled_logits = scaled_logits / divisor.unsqueeze(1)

    plain_logprobs = F.log_softmax(scaled_logits, axis=-1)

    if top_p_flags is None or share_inputs is None:
        return plain_logprobs

    per_token_top_p = sampling_metadata.top_p[:real_bsz].squeeze(1)
    per_token_top_p = per_token_top_p.repeat_interleave(tokens_per_request).unsqueeze(1)
    per_token_top_p_flag = top_p_flags[:real_bsz].astype("int32").squeeze(1)
    per_token_top_p_flag = (
        per_token_top_p_flag.repeat_interleave(tokens_per_request).astype("bool").unsqueeze(1)
    )
    # Only tokens that request normalization AND have top_p != 1.0 are affected.
    renormalize_mask = paddle.logical_and(per_token_top_p_flag, per_token_top_p != 1.0)

    if not renormalize_mask.any():
        return plain_logprobs

    renormalized_probs = top_p_normalize_probs_paddle(
        F.softmax(scaled_logits, axis=-1), per_token_top_p
    )
    return paddle.where(renormalize_mask, paddle.log(renormalized_probs), plain_logprobs)
|
||||
|
||||
def gather_logprobs(
    self,
    logprobs: paddle.Tensor,
    num_logprobs: int,
    token_ids: paddle.Tensor,
) -> LogprobsTensors:
    """
    Collect the logprob and rank of each given token, plus the top-k entries.

    Note: ``logprobs`` is clamped in place (``clip_``), so the caller's
    tensor is modified.

    Args:
        logprobs: (num tokens) x (vocab) tensor.
        num_logprobs: Number of top entries to keep per token. When it is
            below 1, only the given token's own entry is returned.
        token_ids: Prompt tokens (for prompt logprobs) or sampled tokens
            (for sampled logprobs); 1D tensor with (num tokens) elements.
            Must be int64.

    Returns:
        A ``LogprobsTensors`` built from, in order:
        token indices, (num tokens) x (num_logprobs + 1), with the given
        token in the first column;
        the matching logprobs, (num tokens) x (num_logprobs + 1);
        the rank of the given token, (num tokens).
    """
    assert token_ids.dtype == paddle.int64
    chosen_ids = token_ids.unsqueeze(1)

    # Raise anything below the dtype's lowest finite value up to that value.
    lowest_finite = paddle.finfo(logprobs.dtype).min
    logprobs.clip_(min=lowest_finite)

    # Logprob of each prompt/sampled token.
    chosen_logprobs = paddle.take_along_axis(logprobs, chosen_ids, axis=-1)

    # Rank = how many vocab entries score at least as high as the chosen
    # token; the token's own entry is part of that count.
    at_least_as_likely = logprobs >= chosen_logprobs
    chosen_ranks = at_least_as_likely.sum(-1)

    if num_logprobs >= 1:
        best_logprobs, best_ids = paddle.topk(logprobs, num_logprobs, axis=-1)
        return LogprobsTensors(
            paddle.concat([chosen_ids, best_ids], axis=1),
            paddle.concat([chosen_logprobs, best_logprobs], axis=1),
            chosen_ranks,
        )

    return LogprobsTensors(chosen_ids, chosen_logprobs, chosen_ranks)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
@@ -521,7 +615,56 @@ class SpeculativeSampler(nn.Layer):
|
||||
accept_all_drafts,
|
||||
)
|
||||
|
||||
return None
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
batch_token_num = None
|
||||
if num_logprobs is not None:
|
||||
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||
batch_token_num = paddle.where(
|
||||
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
|
||||
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
|
||||
share_inputs["accept_num"][:real_bsz].unsqueeze(1),
|
||||
).squeeze(1)
|
||||
share_inputs["batch_token_num"] = batch_token_num
|
||||
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
|
||||
"int32"
|
||||
)
|
||||
cu_batch_token_offset = paddle.concat(
|
||||
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
|
||||
).astype("int32")
|
||||
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
|
||||
target_logtis = paddle.empty(
|
||||
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
|
||||
)
|
||||
speculate_get_target_logits(
|
||||
target_logtis,
|
||||
logits,
|
||||
cu_batch_token_offset,
|
||||
ori_cu_batch_token_offset,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["accept_num"],
|
||||
)
|
||||
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
|
||||
|
||||
logprobs_tensors = None
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
if num_logprobs is not None:
|
||||
token_ids = paddle.concat(
|
||||
[
|
||||
share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
|
||||
for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
|
||||
]
|
||||
)
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=token_ids,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=batch_token_num,
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
)
|
||||
|
||||
return sampler_output
|
||||
|
||||
|
||||
class MTPSampler(nn.Layer):
|
||||
@@ -556,6 +699,103 @@ class MTPSampler(nn.Layer):
|
||||
"""post process after running"""
|
||||
pass
|
||||
|
||||
def compute_logprobs(
    self,
    logits: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
    """
    Compute log-probabilities from logits, honoring the optional
    temperature-scaling and top-p-normalization flags.

    Per-request values from ``sampling_metadata`` are repeated per token
    according to ``share_inputs["batch_token_num"]``. That entry is only
    read when one of the two optional flag tensors is present.

    Args:
        logits: Logits to convert; log-softmax runs over the last axis.
            Presumably one row per token -- TODO confirm against caller.
        sampling_metadata: Provides ``share_inputs``,
            ``temp_scaled_logprobs``, ``top_p_normalized_logprobs``,
            ``temperature`` and ``top_p``.

    Returns:
        Log-probabilities over the last axis of ``logits``; rows picked by
        the top-p mask carry ``log`` of the top-p normalized probabilities.
    """
    share_inputs = sampling_metadata.share_inputs
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]

    def per_token(per_request):
        # Repeat each request's value once per token of that request.
        return per_request.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])

    working_logits = logits
    scale_flags = sampling_metadata.temp_scaled_logprobs
    normalize_flags = sampling_metadata.top_p_normalized_logprobs

    if scale_flags is not None:
        batch_flags = scale_flags[:real_bsz]
        batch_temperature = sampling_metadata.temperature[:real_bsz]
        # Flags are expanded as int32 and cast back to bool afterwards.
        token_flags = per_token(batch_flags.astype("int32").squeeze(1)).astype("bool")
        token_temperature = per_token(batch_temperature.squeeze(1))
        # Where the flag is off the divisor is 1, leaving logits untouched.
        effective_temperature = paddle.where(
            token_flags, token_temperature, paddle.ones_like(token_temperature)
        )
        working_logits = working_logits / effective_temperature.unsqueeze(1)

    result = F.log_softmax(working_logits, axis=-1)

    if normalize_flags is None or share_inputs is None:
        return result

    token_top_p = per_token(sampling_metadata.top_p[:real_bsz].squeeze(1)).unsqueeze(1)
    token_normalize = (
        per_token(normalize_flags[:real_bsz].astype("int32").squeeze(1)).astype("bool").unsqueeze(1)
    )
    # Normalization applies only where it is requested and top_p differs from 1.0.
    token_mask = paddle.logical_and(token_normalize, token_top_p != 1.0)

    if token_mask.any():
        normalized = top_p_normalize_probs_paddle(F.softmax(working_logits, axis=-1), token_top_p)
        result = paddle.where(token_mask, paddle.log(normalized), result)
    return result
|
||||
|
||||
def gather_logprobs(
    self,
    logprobs: paddle.Tensor,
    num_logprobs: int,
    token_ids: paddle.Tensor,
) -> LogprobsTensors:
    """
    Gather the logprob, rank and top-k neighbours for each given token.

    ``logprobs`` is clamped in place before use, so the tensor passed in
    is modified.

    Args:
        logprobs: (num tokens) x (vocab) tensor.
        num_logprobs: How many top entries to retain per token; values
            below 1 return only the given token's own column.
        token_ids: 1D int64 tensor with (num tokens) elements holding the
            prompt tokens (prompt logprobs) or the sampled tokens
            (sampled logprobs).

    Returns:
        ``LogprobsTensors`` holding:
        int indices, (num tokens) x (num_logprobs + 1), given token first;
        float logprobs, (num tokens) x (num_logprobs + 1);
        rank of the given token, (num tokens).
    """
    assert token_ids.dtype == paddle.int64
    token_column = token_ids.unsqueeze(1)

    # In-place lower bound at the dtype's lowest finite value.
    logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)

    # Logprob of the prompt or sampled token itself.
    selected_logprobs = paddle.take_along_axis(logprobs, token_column, axis=-1)

    # Rank counts every vocab entry scoring >= the selected token.
    selected_ranks = (logprobs >= selected_logprobs).sum(-1)

    if num_logprobs >= 1:
        topk_values, topk_ids = paddle.topk(logprobs, num_logprobs, axis=-1)
        out_indices = paddle.concat([token_column, topk_ids], axis=1)
        out_logprobs = paddle.concat([selected_logprobs, topk_values], axis=1)
    else:
        out_indices, out_logprobs = token_column, selected_logprobs

    return LogprobsTensors(out_indices, out_logprobs, selected_ranks)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
@@ -564,6 +804,12 @@ class MTPSampler(nn.Layer):
|
||||
share_inputs: List[paddle.Tensor],
|
||||
) -> paddle.Tensor:
|
||||
""" """
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||
if num_logprobs is not None and share_inputs["substep"] == 0:
|
||||
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
|
||||
raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"][:real_token_num, :], sampling_metadata)
|
||||
|
||||
logits = apply_speculative_penalty_multi_scores(
|
||||
sampling_metadata.pre_token_ids,
|
||||
logits,
|
||||
@@ -585,4 +831,27 @@ class MTPSampler(nn.Layer):
|
||||
_, next_tokens = top_k_top_p_sampling(
|
||||
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
|
||||
)
|
||||
return next_tokens
|
||||
|
||||
token_ids = None
|
||||
logprobs_tensors = None
|
||||
if num_logprobs is not None and share_inputs["substep"] == 0:
|
||||
token_ids = paddle.empty(real_token_num, dtype="int64")
|
||||
speculate_insert_first_token(
|
||||
token_ids,
|
||||
share_inputs["accept_tokens"],
|
||||
next_tokens,
|
||||
share_inputs["cu_next_token_offset"],
|
||||
share_inputs["cu_batch_token_offset"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
)
|
||||
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=token_ids,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=share_inputs["batch_token_num"][:real_bsz],
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
)
|
||||
return next_tokens, sampler_output
|
||||
|
||||
Reference in New Issue
Block a user